# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from scipy.fft import fft, ifft
# import cv2
import pywt

def calculate_snr(original, watermarked):
    """Calculate the Signal-to-Noise Ratio in dB.

    The noise is the sample-wise difference between the two signals.

    Args:
        original: Reference signal (array-like of real numbers).
        watermarked: Modified signal, broadcast-compatible with ``original``.

    Returns:
        ``10 * log10(signal_power / noise_power)``. ``inf`` when the signals
        are identical (no noise) and ``-inf`` when the original carries no
        energy but the difference does.
    """
    # Compute in float64: squaring integer PCM samples (e.g. int16) would
    # otherwise wrap around and produce negative "power".
    original = np.asarray(original, dtype=np.float64)
    watermarked = np.asarray(watermarked, dtype=np.float64)

    noise_power = np.sum((original - watermarked) ** 2)
    if noise_power == 0:
        return float('inf')  # Identical signals: no noise at all

    signal_power = np.sum(original ** 2)
    if signal_power == 0:
        return float('-inf')  # Silent original: avoid log10(0)

    return 10 * np.log10(signal_power / noise_power)

def calculate_nr(original, watermarked):
    """Calculate the Noise Ratio (noise power divided by signal power).

    Args:
        original: Reference signal (array-like of real numbers).
        watermarked: Modified signal, broadcast-compatible with ``original``.

    Returns:
        ``noise_power / signal_power``. ``0.0`` when the signals are
        identical and ``inf`` when the original carries no energy but the
        difference does.
    """
    # Compute in float64: squaring integer PCM samples (e.g. int16) would
    # otherwise wrap around and produce negative "power".
    original = np.asarray(original, dtype=np.float64)
    watermarked = np.asarray(watermarked, dtype=np.float64)

    noise_power = np.sum((original - watermarked) ** 2)
    if noise_power == 0:
        return 0.0  # Identical signals: no noise at all

    signal_power = np.sum(original ** 2)
    if signal_power == 0:
        return float('inf')  # Silent original: avoid dividing by zero

    return noise_power / signal_power

def calculate_ber(original_watermark, extracted_watermark):
    """Calculate the Bit Error Rate between two watermarks.

    Args:
        original_watermark: Embedded bits (list or array).
        extracted_watermark: Recovered bits (list or array).

    Returns:
        Fraction of positions at which the two watermarks differ, in [0, 1].

    Raises:
        ValueError: If the watermarks differ in length or are empty.
    """
    # Coerce to arrays so the comparison is element-wise; comparing two plain
    # lists with != yields a single bool and would count at most one error.
    original_bits = np.asarray(original_watermark).ravel()
    extracted_bits = np.asarray(extracted_watermark).ravel()

    if original_bits.size != extracted_bits.size:
        raise ValueError("Original and extracted watermarks must have the same length")
    if original_bits.size == 0:
        raise ValueError("Watermarks must not be empty")

    errors = np.count_nonzero(original_bits != extracted_bits)
    return errors / original_bits.size

# 1. Spread Spectrum Watermarking
def spread_spectrum_embed(host_signal, watermark, alpha=0.1):
    """Embed a watermark using the spread spectrum technique.

    The host is divided into ``len(watermark)`` equal chunks; each chunk
    receives the pseudo-random spreading code multiplied by +1 (bit 1) or
    -1 (bit 0), scaled by ``alpha``. Samples left over after the last full
    chunk are not modified.

    Args:
        host_signal: 1-D host signal. Non-float input is promoted to float64
            so the fractional watermark is not truncated away.
        watermark: Sequence of bits (0/1).
        alpha: Embedding strength.

    Returns:
        Tuple ``(watermarked_signal, spreading_code)``; the spreading code
        has the same length as the host and is required for extraction.
    """
    host = np.asarray(host_signal)
    if not np.issubdtype(host.dtype, np.floating):
        host = host.astype(np.float64)
    bits = np.asarray(watermark, dtype=np.float64)

    # A private generator yields the same sequence as seeding the global RNG
    # with 42, but does not clobber the caller's random state.
    spreading_code = np.random.RandomState(42).randn(len(host))

    chunk_size = len(host) // len(bits) if len(bits) else 0
    covered = chunk_size * len(bits)

    # Map bits {0, 1} to signs {-1, +1} and stretch each over its chunk
    signs = np.repeat(2 * bits - 1, chunk_size)
    spread_watermark = np.zeros_like(host)
    spread_watermark[:covered] = spreading_code[:covered] * signs

    watermarked_signal = host + alpha * spread_watermark

    return watermarked_signal, spreading_code

def spread_spectrum_extract(watermarked_signal, original_signal, spreading_code, watermark_length):
    """Recover spread spectrum watermark bits (non-blind: needs the original).

    The difference between the watermarked and original signals is
    correlated, chunk by chunk, against the spreading code; a strictly
    positive correlation decodes as 1, anything else as 0.

    Args:
        watermarked_signal: Signal carrying the watermark.
        original_signal: Unmarked host signal of the same length.
        spreading_code: Code returned by the embedder.
        watermark_length: Number of bits to recover.

    Returns:
        Integer array of ``watermark_length`` bits.
    """
    residual = watermarked_signal - original_signal
    bits = np.zeros(watermark_length, dtype=int)

    # Width of the chunk that carries one bit
    span = len(watermarked_signal) // watermark_length if watermark_length else 0

    for index in range(watermark_length):
        lo = index * span
        hi = lo + span
        score = np.sum(residual[lo:hi] * spreading_code[lo:hi])
        if score > 0:
            bits[index] = 1

    return bits

# 2. Echo Hiding
def echo_hiding_embed(host_signal, watermark, delay0=4, delay1=8, alpha=0.5):
    """Embed a watermark using the echo hiding technique.

    The host is split into ``len(watermark)`` segments (the last one absorbs
    any remainder). Each segment gets a copy of itself added back, delayed
    by ``delay1`` samples for a 1 bit or ``delay0`` samples for a 0 bit and
    scaled by ``alpha``. Echoes do not spill over into the next segment.

    Args:
        host_signal: 1-D host signal. Non-float input is promoted to float64
            so the scaled echo is not truncated to whole numbers.
        watermark: Sequence of bits (0/1).
        delay0: Echo delay in samples encoding a 0 bit.
        delay1: Echo delay in samples encoding a 1 bit.
        alpha: Echo amplitude relative to the host.

    Returns:
        Watermarked signal with the same length as the host.
    """
    host = np.asarray(host_signal)
    if not np.issubdtype(host.dtype, np.floating):
        host = host.astype(np.float64)

    watermarked_signal = np.copy(host)
    num_bits = len(watermark)
    segment_size = len(host) // num_bits

    for i, bit in enumerate(watermark):
        start = i * segment_size
        # The final segment runs to the end so no samples are left unmarked
        end = start + segment_size if i < num_bits - 1 else len(host)
        segment = host[start:end]

        # Choose delay based on watermark bit
        delay = delay1 if bit == 1 else delay0

        # Delayed, attenuated copy of the segment
        echo = np.zeros_like(segment)
        echo[delay:] = segment[:-delay] * alpha

        watermarked_signal[start:end] = segment + echo

    return watermarked_signal

def echo_hiding_extract(watermarked_signal, delay0=4, delay1=8, watermark_length=8):
    """Recover echo hiding bits by comparing cepstrum values at both delays.

    Each segment's cepstrum magnitude is computed as
    ``|ifft(log(|fft(segment)| + 1e-10))|``. The bit decodes as 1 when the
    value at index ``delay1`` exceeds the value at index ``delay0``; a delay
    beyond the segment length counts as amplitude 0.

    Args:
        watermarked_signal: Signal carrying the echoes.
        delay0: Echo delay in samples that encodes a 0 bit.
        delay1: Echo delay in samples that encodes a 1 bit.
        watermark_length: Number of bits (segments) to decode.

    Returns:
        Integer array of ``watermark_length`` bits.
    """
    segment_size = len(watermarked_signal) // watermark_length
    bits = np.zeros(watermark_length, dtype=int)
    total = len(watermarked_signal)

    for index in range(watermark_length):
        begin = index * segment_size
        # The last segment takes whatever remains of the signal
        is_last = index == watermark_length - 1
        stop = total if is_last else begin + segment_size
        frame = watermarked_signal[begin:stop]

        # Small offset keeps the logarithm finite for empty frequency bins
        magnitude = np.abs(fft(frame))
        cepstrum = np.abs(ifft(np.log(magnitude + 1e-10)))

        size = len(cepstrum)
        peak0 = cepstrum[delay0] if delay0 < size else 0
        peak1 = cepstrum[delay1] if delay1 < size else 0

        if peak1 > peak0:
            bits[index] = 1

    return bits

# 3. Phase Coding
def phase_coding_embed(host_signal, watermark, segment_length=1024):
    """Embed a watermark using the phase coding technique.

    One bit is stored per segment by forcing the phase of FFT bin 1 to
    +pi/2 (bit 1) or -pi/2 (bit 0) while leaving every magnitude untouched.
    The mirror bin (N-1) receives the negated phase so the spectrum stays
    conjugate-symmetric and the inverse FFT is exactly real. No phase offset
    is applied across segments: the extractor reads bin 1 of each segment
    independently, so shifting all phases would overwrite the embedded bit.

    Args:
        host_signal: 1-D host signal. Non-float input is promoted to float64.
        watermark: Sequence of bits (0/1). Bits beyond the number of
            available segments are silently dropped.
        segment_length: Samples per segment; must be at least 3 so that
            bin 1 is neither the DC nor the Nyquist bin.

    Returns:
        Watermarked signal with the same length as the host. Segments past
        the end of the watermark are returned unchanged.

    Raises:
        ValueError: If ``segment_length`` is smaller than 3.
    """
    if segment_length < 3:
        raise ValueError("segment_length must be at least 3")

    host = np.asarray(host_signal)
    if not np.issubdtype(host.dtype, np.floating):
        host = host.astype(np.float64)

    # Zero-pad so the signal splits into whole segments
    padding = -len(host) % segment_length
    if padding > 0:
        host = np.pad(host, (0, padding), 'constant')

    num_segments = len(host) // segment_length
    watermarked_signal = np.copy(host)

    # Only as many bits as there are segments can be embedded
    watermark = watermark[:num_segments]

    for i, bit in enumerate(watermark):
        start = i * segment_length
        end = start + segment_length

        spectrum = fft(host[start:end])
        magnitude = np.abs(spectrum)
        phase = np.angle(spectrum)

        # Encode the bit in bin 1 and mirror it into bin N-1
        target = np.pi / 2 if bit == 1 else -np.pi / 2
        phase[1] = target
        phase[-1] = -target

        # Rebuild the segment from the original magnitudes and new phases
        modified_spectrum = magnitude * np.exp(1j * phase)
        watermarked_signal[start:end] = np.real(ifft(modified_spectrum))

    # Remove padding if added
    if padding > 0:
        watermarked_signal = watermarked_signal[:-padding]

    return watermarked_signal

def phase_coding_extract(watermarked_signal, segment_length=1024, watermark_length=8):
    """Recover phase coding bits from the phase of FFT bin 1 of each segment.

    The signal is zero-padded to a whole number of segments. A segment whose
    bin-1 phase is strictly positive decodes as 1, otherwise as 0.

    Args:
        watermarked_signal: Signal carrying the watermark.
        segment_length: Samples per segment.
        watermark_length: Maximum number of bits to decode.

    Returns:
        Integer array holding ``min(watermark_length, number_of_segments)``
        bits.
    """
    leftover = len(watermarked_signal) % segment_length
    if leftover != 0:
        # Fill the last partial segment with zeros
        watermarked_signal = np.pad(
            watermarked_signal, (0, segment_length - leftover), 'constant'
        )

    bit_count = min(watermark_length, len(watermarked_signal) // segment_length)
    bits = np.zeros(bit_count, dtype=int)

    for index in range(bit_count):
        offset = index * segment_length
        frame = watermarked_signal[offset:offset + segment_length]
        first_bin_phase = np.angle(fft(frame))[1]
        if first_bin_phase > 0:
            bits[index] = 1

    return bits

# 4. Quantization Index Modulation (QIM)
def qim_embed(host_signal, watermark, delta=10):
    """Embed a watermark using Quantization Index Modulation (QIM).

    The host is split into ``len(watermark)`` segments (the last one absorbs
    any remainder). Samples in a 0-bit segment are rounded to the nearest
    multiple of ``delta``; samples in a 1-bit segment are rounded onto the
    lattice ``k * delta - delta / 4``. Rounding is half-to-even.

    Args:
        host_signal: 1-D host signal. Non-float input is promoted to float64
            so the quarter-step offset is not truncated.
        watermark: Non-empty sequence of bits (0/1).
        delta: Quantization step size, in the same units as the host samples.

    Returns:
        Quantized (watermarked) signal with the same length as the host.
    """
    host = np.asarray(host_signal)
    if not np.issubdtype(host.dtype, np.floating):
        host = host.astype(np.float64)
    bits = np.asarray(watermark)

    num_bits = len(bits)
    segment_size = len(host) // num_bits

    # Per-sample flag: True where the sample belongs to a 1-bit segment.
    # Everything from the start of the last segment onward uses the last bit.
    use_offset = np.empty(len(host), dtype=bool)
    boundary = segment_size * (num_bits - 1)
    use_offset[:boundary] = np.repeat(bits[:-1] == 1, segment_size)
    use_offset[boundary:] = bits[-1] == 1

    # Quantize the whole signal onto both lattices, then pick per sample
    scaled = host / delta
    plain_lattice = delta * np.round(scaled)
    offset_lattice = delta * np.round(scaled + 0.25) - 0.25 * delta

    return np.where(use_offset, offset_lattice, plain_lattice)

def qim_extract(watermarked_signal, watermark_length, delta=10):
    """Recover QIM watermark bits by minimum quantization error.

    For each segment the mean squared distance of its samples to the 0-bit
    lattice (multiples of ``delta``) and to the 1-bit lattice
    (``k * delta - delta / 4``) is computed; the bit decodes as 1 only when
    the 1-bit lattice is strictly closer.

    Args:
        watermarked_signal: Signal carrying the watermark (array-like).
        watermark_length: Number of bits (segments) to decode.
        delta: Quantization step size used by the embedder.

    Returns:
        Integer array of ``watermark_length`` bits.
    """
    samples = np.asarray(watermarked_signal, dtype=np.float64)
    segment_size = len(samples) // watermark_length
    extracted_watermark = np.zeros(watermark_length, dtype=int)

    # Squared distance of every sample to each lattice, computed in one pass
    scaled = samples / delta
    error0 = (samples - delta * np.round(scaled)) ** 2
    error1 = (samples - (delta * np.round(scaled + 0.25) - 0.25 * delta)) ** 2

    for i in range(watermark_length):
        start = i * segment_size
        # The final segment runs to the end of the signal
        end = start + segment_size if i < watermark_length - 1 else len(samples)
        if end <= start:
            continue  # Empty segment: nothing to measure, bit stays 0

        if error1[start:end].mean() < error0[start:end].mean():
            extracted_watermark[i] = 1

    return extracted_watermark

def evaluate_watermarking_methods(host_signal, watermark, attack_function=None,
                                  alpha_ss=0.1, delay0=4, delay1=8,
                                  segment_length=1024, delta=10):
    """Evaluate all four watermarking methods on one host signal.

    Each method embeds ``watermark`` into ``host_signal``; SNR and NR are
    measured on the un-attacked watermarked signal, and BER on the bits
    recovered after the optional attack.

    Args:
        host_signal: 1-D host signal.
        watermark: Sequence of bits (0/1) to embed.
        attack_function: Optional callable applied to each watermarked
            signal before extraction. Spread spectrum extraction subtracts
            the host, so the attack must preserve the signal length.
        alpha_ss: Spread spectrum embedding strength.
        delay0: Echo hiding delay (samples) for a 0 bit.
        delay1: Echo hiding delay (samples) for a 1 bit.
        segment_length: Phase coding segment length in samples.
        delta: QIM quantization step, in host-sample units. Note that for a
            host normalized to [-1, 1] the default of 10 quantizes every
            sample to 0 or -2.5, so pass a smaller step for such signals.

    Returns:
        Dict mapping method name ("Spread Spectrum", "Echo Hiding",
        "Phase Coding", "QIM") to a dict with keys "SNR", "NR" and "BER".
    """
    def apply_attack(marked):
        # Without an attack, extraction runs on the watermarked signal itself
        return attack_function(marked) if attack_function else marked

    def distortion(marked):
        # Embedding distortion relative to the untouched host
        return {
            "SNR": calculate_snr(host_signal, marked),
            "NR": calculate_nr(host_signal, marked),
        }

    num_bits = len(watermark)
    results = {}

    # 1. Spread Spectrum
    print("Evaluating Spread Spectrum...")
    ss_watermarked, spreading_code = spread_spectrum_embed(host_signal, watermark, alpha=alpha_ss)
    results["Spread Spectrum"] = distortion(ss_watermarked)
    ss_extracted = spread_spectrum_extract(
        apply_attack(ss_watermarked), host_signal, spreading_code, num_bits
    )
    results["Spread Spectrum"]["BER"] = calculate_ber(watermark, ss_extracted)

    # 2. Echo Hiding
    print("Evaluating Echo Hiding...")
    eh_watermarked = echo_hiding_embed(host_signal, watermark, delay0=delay0, delay1=delay1)
    results["Echo Hiding"] = distortion(eh_watermarked)
    eh_extracted = echo_hiding_extract(
        apply_attack(eh_watermarked), delay0=delay0, delay1=delay1, watermark_length=num_bits
    )
    results["Echo Hiding"]["BER"] = calculate_ber(watermark, eh_extracted)

    # 3. Phase Coding
    print("Evaluating Phase Coding...")
    pc_watermarked = phase_coding_embed(host_signal, watermark, segment_length=segment_length)
    results["Phase Coding"] = distortion(pc_watermarked)
    pc_extracted = phase_coding_extract(
        apply_attack(pc_watermarked), segment_length=segment_length, watermark_length=num_bits
    )
    # Phase coding stores one bit per segment, so a short host carries fewer
    # bits than the watermark; score only the bits that were embedded rather
    # than failing on a length mismatch.
    results["Phase Coding"]["BER"] = calculate_ber(watermark[:len(pc_extracted)], pc_extracted)

    # 4. QIM
    print("Evaluating QIM...")
    qim_watermarked = qim_embed(host_signal, watermark, delta=delta)
    results["QIM"] = distortion(qim_watermarked)
    qim_extracted = qim_extract(apply_attack(qim_watermarked), num_bits, delta=delta)
    results["QIM"]["BER"] = calculate_ber(watermark, qim_extracted)

    return results

def visualize_results(results):
    """Draw one bar chart per metric comparing the watermarking methods.

    Args:
        results: Dict mapping method name to a dict that contains the keys
            "SNR", "NR" and "BER".

    Returns:
        The matplotlib figure holding the three side-by-side charts, each
        annotated with the best-scoring method.
    """
    method_names = list(results)

    figure, panels = plt.subplots(1, 3, figsize=(18, 6))

    for panel, metric in zip(panels, ("SNR", "NR", "BER")):
        heights = [results[name][metric] for name in method_names]
        panel.bar(method_names, heights)
        panel.set_title(f"{metric} Comparison")
        panel.set_ylabel(metric)

        # Higher is better for SNR; lower is better for NR and BER
        pick_best = np.argmax if metric == "SNR" else np.argmin
        best_method = method_names[pick_best(heights)]
        panel.text(0.5, 0.9, f"Best: {best_method}", transform=panel.transAxes, ha='center')

    plt.tight_layout()
    return figure

# Example usage
def main(audio_path='./archive/dev/real/B_0000_5_A.wav'):
    """Run the watermarking comparison on a WAV file and save a chart.

    Evaluates every method without an attack and again under additive
    Gaussian noise, prints the metrics, and writes the no-attack comparison
    to 'watermarking_metrics_comparison.png'.

    Args:
        audio_path: Path of the WAV file used as the host signal.
    """
    from scipy.io import wavfile

    _sample_rate, host_signal = wavfile.read(audio_path)

    # Convert integer PCM to float in roughly [-1, 1]
    if host_signal.dtype == np.int16:
        host_signal = host_signal.astype(np.float32) / 32767.0
    elif host_signal.dtype == np.int32:
        host_signal = host_signal.astype(np.float32) / 2147483647.0

    # If stereo, use just one channel
    if len(host_signal.shape) > 1:
        host_signal = host_signal[:, 0]

    # Create a sample watermark
    watermark = np.array([1, 0, 1, 1, 0, 0, 1, 0], dtype=int)

    # Proceed with evaluation
    print("Evaluating watermarking methods with no attack...")
    results = evaluate_watermarking_methods(host_signal, watermark)

    # Print results
    print("\nResults:")
    for method, metrics in results.items():
        print(f"{method}:")
        for metric, value in metrics.items():
            print(f"  {metric}: {value:.6f}")

    # Attack helpers. Their parameter is deliberately not called `signal`:
    # that name would shadow the scipy.signal module used by the filter.
    def gaussian_noise_attack(audio, std=0.01):
        """Add zero-mean Gaussian noise to the signal."""
        return audio + np.random.normal(0, std, len(audio))

    def low_pass_filter_attack(audio, cutoff=0.1):
        """Apply a 5th-order Butterworth low-pass filter to the signal."""
        b, a = signal.butter(5, cutoff, 'low')
        return signal.lfilter(b, a, audio)

    def compression_attack(audio, factor=0.8):
        """Simulate compression by keeping only the most significant coefficients."""
        # Apply wavelet transform
        coeffs = pywt.wavedec(audio, 'db4', level=3)
        # Zero the smallest (1 - factor) share of coefficients in each band
        for i in range(len(coeffs)):
            threshold = np.percentile(np.abs(coeffs[i]), 100 * (1 - factor))
            coeffs[i] = pywt.threshold(coeffs[i], threshold, mode='hard')
        # Reconstruction can come back one sample longer than the input;
        # trim so the result stays aligned with the host signal.
        return pywt.waverec(coeffs, 'db4')[:len(audio)]

    # Evaluate with Gaussian noise attack
    print("\nEvaluating watermarking methods with Gaussian noise attack...")
    noise_attack_results = evaluate_watermarking_methods(
        host_signal, watermark,
        attack_function=gaussian_noise_attack
    )

    print("\nResults after Gaussian noise attack:")
    for method, metrics in noise_attack_results.items():
        print(f"{method}:")
        print(f"  BER: {metrics['BER']:.6f}")

    # Visualize results
    fig = visualize_results(results)
    plt.savefig('watermarking_metrics_comparison.png')
    plt.close(fig)

    print("\nResults visualization saved as 'watermarking_metrics_comparison.png'")

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()