# Detector Degradation Analysis: SNR and Redshift Reach
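This notebook analyzes how detector degradation affects the signal-to-noise ratio (SNR) and redshift reach of extreme mass-ratio inspiral (EMRI) sources. The degradation factor $d$ is applied as $\text{SNR}_{\text{degraded}} = \text{SNR}_{\text{original}} / \sqrt{d}$. This is equivalent to multiplying the detector noise power spectral density $S_n$ by $d$, since $\text{SNR} \propto 1/\sqrt{S_n}$.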

In [1]:
import h5py
import numpy as np
import glob
import json
import os
import matplotlib.pyplot as plt
import scienceplots
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from scipy.interpolate import interp1d
from ipywidgets import Dropdown, FloatSlider, HBox, Output, interact, widgets

# Apply the 'science' Matplotlib style provided by the `scienceplots` package
# imported above (that import is otherwise unused; it is needed for the style).
# NOTE(review): this style may enable LaTeX text rendering, which requires a
# working LaTeX installation — confirm, or use ['science', 'no-latex'].
plt.style.use('science')

## Load and Organize Detection Data from All Sources

In [None]:
# Load every snr_<id>/detection.h5 file found relative to the current working
# directory. Each file holds one source's scalar parameters plus an `snr` array
# indexed as (redshift, realization) that matches the 1-D `redshift` array.
def parse_source_id(det_path):
    """Return the integer source id encoded in a 'snr_<id>/detection.h5' path.

    Parses with os.path instead of splitting on '/', so it also works with the
    native path separator that glob returns on Windows.

    Raises ValueError if the parent directory is not named 'snr_<integer>'.
    """
    dir_name = os.path.basename(os.path.dirname(det_path))
    prefix, sep, id_str = dir_name.partition('_')
    if prefix != 'snr' or not sep:
        raise ValueError(f"Unexpected detection file location: {det_path!r}")
    return int(id_str)


# Sort numerically by source id (plain sorted() would put snr_10 before snr_2)
detection_files = sorted(glob.glob('snr_*/detection.h5'), key=parse_source_id)
print(f"Found {len(detection_files)} detection.h5 files")

# Dictionaries indexed by source_id
source_metadata = {}   # source_id -> scalar parameters (m1, m2, a, p0, e0, T)
source_snr_data = {}   # source_id -> {redshift (float): SNR values over realizations}

for idx, det_file in enumerate(detection_files):
    source_id = parse_source_id(det_file)
    print(f"Loading source {idx+1}/{len(detection_files)}: {det_file} (ID={source_id})")

    with h5py.File(det_file, 'r') as f:
        # Extract scalar metadata
        source_metadata[source_id] = {
            'm1': float(f['m1'][()]),
            'm2': float(f['m2'][()]),
            'a': float(f['a'][()]),
            'p0': float(f['p0'][()]),
            'e0': float(f['e0'][()]),
            'T': float(f['Tpl'][()]),  # stored under 'Tpl' in the file
        }

        # SNR as a function of redshift: one row of `snr` per entry of `redshift`
        snr_data = f['snr'][()]
        redshifts = f['redshift'][()]
        # Guard against silent misalignment between SNR rows and redshifts
        if snr_data.ndim != 2 or snr_data.shape[0] != len(redshifts):
            raise ValueError(
                f"{det_file}: 'snr' has shape {snr_data.shape}, expected "
                f"({len(redshifts)}, n_realizations) to match 'redshift'"
            )

        # Store SNR realizations keyed by redshift (as a plain float)
        source_snr_data[source_id] = {
            float(z_val): snr_data[z_idx, :] for z_idx, z_val in enumerate(redshifts)
        }

print("\nData loading complete!")
print(f"Loaded metadata for {len(source_metadata)} sources")
print(f"Loaded SNR data for {len(source_snr_data)} sources")

In [None]:
# Unique parameter values offered by the dropdowns, and the union of every
# source's redshift grid. Fail early with a clear message when nothing was
# loaded, instead of an opaque min()/IndexError further down the notebook.
if not source_metadata:
    raise RuntimeError(
        "No detection data loaded: no files matched 'snr_*/detection.h5' "
        "in the current working directory."
    )

Tpl_values = sorted({meta['T'] for meta in source_metadata.values()})
a_values = sorted({meta['a'] for meta in source_metadata.values()})

# Collect all unique redshifts across all sources
all_redshifts = sorted({z for snr_by_z in source_snr_data.values() for z in snr_by_z})
if not all_redshifts:
    raise RuntimeError("Loaded sources contain no redshift samples.")

print(f"Unique Tpl values: {Tpl_values}")
print(f"Unique spin (a) values: {a_values}")
print(f"Number of unique redshifts: {len(all_redshifts)}")
print(f"Redshift range: {min(all_redshifts):.6f} to {max(all_redshifts):.6f}")

## Interactive Degradation Analysis Plots
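Adjust the sliders and dropdowns to see how detector degradation affects SNR and redshift reach. The degradation factor $d$ scales the SNR as $\text{SNR}_{\text{degraded}} = \text{SNR}_{\text{original}} / \sqrt{d}$.

- **Top panel:** median SNR (over realizations) versus primary mass $m_1$ at the selected redshift, with one curve per secondary mass $m_2$.
- **Bottom panel:** redshift reach versus $m_1$. The reach is the redshift at which the median SNR reaches the chosen SNR threshold.

In both panels, solid lines with circles show the original detector and dashed lines with squares show the degraded one. Each arrow points from an original value to its degraded value.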

In [None]:
# Controls for the degradation plot. Each dropdown displays a formatted label
# but yields the raw float value; spin defaults to the largest available value.
tpl_dropdown_deg = Dropdown(
    description='Tpl:',
    options=[(f'Tpl = {tpl:.2f}', tpl) for tpl in Tpl_values],
    value=Tpl_values[0],
)

spin_dropdown_deg = Dropdown(
    description='Spin (a):',
    options=[(f'a = {spin:.2f}', spin) for spin in a_values],
    value=a_values[-1],
)

redshift_dropdown_deg = Dropdown(
    description='Redshift:',
    options=[(f'z = {z:.6f}', z) for z in all_redshifts],
    value=all_redshifts[0],
)

# SNR threshold used for the redshift-reach panel, and the degradation factor d.
# Both sliders push new values while being dragged (continuous_update=True).
snr_threshold_slider_deg = FloatSlider(
    description='SNR threshold:',
    value=20, min=5, max=100, step=5,
    continuous_update=True,
)

degradation_slider = FloatSlider(
    description='Degradation (d):',
    value=2.0, min=1.0, max=5.0, step=0.5,
    continuous_update=True,
)

# Output area that the plotting callback renders into
output_deg_plot = Output()

# (width, height) in inches of the stacked two-panel degradation figure
figsize = (6, 12)

In [None]:
def _find_matching_sources(tpl_val, spin_a, tolerance=1e-6):
    """Return the sorted ids of sources whose Tpl and spin match within `tolerance`."""
    return [
        src_idx
        for src_idx in sorted(source_metadata)
        if abs(source_metadata[src_idx]['a'] - spin_a) < tolerance
        and abs(source_metadata[src_idx]['T'] - tpl_val) < tolerance
    ]


def _redshift_at_snr(snr_per_z, z_vals, snr_threshold):
    """Return the redshift at which a median-SNR curve reaches `snr_threshold`.

    `snr_per_z[i]` is the median SNR at redshift `z_vals[i]`. The redshift is
    found by linear interpolation of z as a function of SNR. Thresholds below
    the smallest SNR on the grid are linearly extrapolated.

    Returns NaN when the threshold exceeds the largest SNR on the grid, or
    when there are too few points to interpolate. Extrapolating in the first
    case would place the reach below the smallest grid redshift, possibly at
    z <= 0, which is meaningless and cannot be drawn on a log axis.
    """
    snr_per_z = np.asarray(snr_per_z, dtype=float)
    z_vals = np.asarray(z_vals, dtype=float)
    if snr_per_z.size < 2 or snr_threshold > np.max(snr_per_z):
        return np.nan
    try:
        interp_func = interp1d(snr_per_z, z_vals, kind='linear',
                               bounds_error=False, fill_value='extrapolate')
    except ValueError:
        return np.nan
    return float(interp_func(snr_threshold))


def _plot_degradation_pair(ax, m1_vals, y_orig, y_deg, color, label):
    """Plot original (solid, circles) and degraded (dashed, squares) values vs m1.

    Points are sorted by m1 first. An arrow is drawn from each original value
    to its degraded value; points where either value is non-finite (e.g. a NaN
    redshift reach) are left without an arrow.
    """
    m1_vals = np.asarray(m1_vals, dtype=float)
    order = np.argsort(m1_vals)
    m1_sorted = m1_vals[order]
    y_orig_sorted = np.asarray(y_orig, dtype=float)[order]
    y_deg_sorted = np.asarray(y_deg, dtype=float)[order]

    ax.plot(m1_sorted, y_orig_sorted, 'o-', color=color,
            markersize=7, linewidth=1.5, label=label, alpha=0.7)
    ax.plot(m1_sorted, y_deg_sorted, 's--', color=color,
            markersize=6, linewidth=1.5, alpha=0.5)

    for x, y_from, y_to in zip(m1_sorted, y_orig_sorted, y_deg_sorted):
        if np.isfinite(y_from) and np.isfinite(y_to):
            ax.annotate('', xy=(x, y_to), xytext=(x, y_from),
                        arrowprops=dict(arrowstyle='->', color=color,
                                        lw=1.5, alpha=0.6))


def plot_degradation_effect(tpl_val, spin_a, z_val, snr_threshold, degradation):
    """
    Plot showing degradation effect on SNR and redshift.

    Top: median SNR (over realizations) vs m1 at redshift `z_val`, original and
    degraded, with an arrow from each original point to its degraded value.
    Bottom: redshift at which the median SNR reaches `snr_threshold` vs m1,
    original and degraded, with the same arrows.
    Curves are grouped by m2 and each m2 has the same colour in both panels.

    Parameters
    ----------
    tpl_val : float
        Tpl value; sources are matched within 1e-6.
    spin_a : float
        Spin a; sources are matched within 1e-6.
    z_val : float
        Redshift for the top panel; must be an exact key of the per-source
        SNR data (sources lacking it are left out of the top panel).
    snr_threshold : float
        SNR threshold used for the redshift-reach panel.
    degradation : float
        Degradation factor d; SNR is divided by sqrt(d).

    Draws into the global `output_deg_plot` Output widget with the global
    `figsize`. Sources that never reach `snr_threshold`, even without
    degradation, are omitted from the bottom panel. A degraded reach that falls
    off the redshift grid is left blank (NaN) rather than extrapolated.
    """
    matching_sources = _find_matching_sources(tpl_val, spin_a)

    if not matching_sources:
        with output_deg_plot:
            output_deg_plot.clear_output(wait=True)
            print(f"No sources found for Tpl={tpl_val:.2f}, a={spin_a:.2f}")
        return

    sqrt_d = np.sqrt(degradation)

    # Top-panel data: median SNR at z_val, grouped by m2
    snr_data = {}
    # Bottom-panel data: redshift reach before/after degradation, grouped by m2
    z_data = {}
    for src_idx in matching_sources:
        meta = source_metadata[src_idx]
        snr_by_z = source_snr_data[src_idx]

        if z_val in snr_by_z:
            group = snr_data.setdefault(meta['m2'], {'m1': [], 'snr_orig': []})
            group['m1'].append(meta['m1'])
            group['snr_orig'].append(np.median(snr_by_z[z_val]))

        z_keys = sorted(snr_by_z)
        snr_median_per_z = np.array([np.median(snr_by_z[z]) for z in z_keys])
        z_vals_array = np.array(z_keys)

        z_at_snr = _redshift_at_snr(snr_median_per_z, z_vals_array, snr_threshold)
        if np.isnan(z_at_snr):
            # Not detectable at this threshold even without degradation
            continue
        z_at_snr_deg = _redshift_at_snr(snr_median_per_z / sqrt_d, z_vals_array,
                                        snr_threshold)

        group = z_data.setdefault(meta['m2'], {'m1': [], 'z_orig': [], 'z_deg': []})
        group['m1'].append(meta['m1'])
        group['z_orig'].append(z_at_snr)
        group['z_deg'].append(z_at_snr_deg)

    # One colour per m2 shared by both panels, so an m2 keeps its colour even
    # when it is missing from one panel.
    all_m2 = sorted(set(snr_data) | set(z_data))
    palette = plt.cm.tab20(np.linspace(0, 1, len(all_m2) or 1))
    m2_colors = dict(zip(all_m2, palette))

    # Create figure with two subplots
    with output_deg_plot:
        output_deg_plot.clear_output(wait=True)
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize, sharex=True)

        # TOP PLOT: SNR vs m1; arrows point from original to degraded SNR
        for m2 in sorted(snr_data):
            snr_orig = np.asarray(snr_data[m2]['snr_orig'], dtype=float)
            _plot_degradation_pair(ax1, snr_data[m2]['m1'], snr_orig, snr_orig / sqrt_d,
                                   color=m2_colors[m2], label=f'm2={m2:.0f}')

        ax1.set_xlabel('Primary Mass m1 ($M_\\odot$)', fontsize=18)
        ax1.set_ylabel('SNR', fontsize=18)
        ax1.set_xscale('log')
        ax1.set_yscale('log')
        ax1.grid(True, alpha=0.3)
        ax1.set_title(f'SNR Degradation | Tpl={tpl_val:.2f}, a={spin_a:.2f}, z={z_val:.6f}\nd={degradation:.1f}', 
                     fontsize=18, fontweight='bold')
        if snr_data:
            ax1.legend(fontsize=9, loc='upper right')

        # BOTTOM PLOT: Redshift reach vs m1; arrows point from original to degraded reach
        for m2 in sorted(z_data):
            _plot_degradation_pair(ax2, z_data[m2]['m1'], z_data[m2]['z_orig'],
                                   z_data[m2]['z_deg'],
                                   color=m2_colors[m2], label=f'm2={m2:.0f}')

        ax2.set_xlabel('Primary Mass m1 ($M_\\odot$)', fontsize=18)
        ax2.set_ylabel('Redshift z', fontsize=18)
        ax2.set_xscale('log')
        ax2.set_yscale('log')
        ax2.grid(True, alpha=0.3)
        ax2.set_title(f'Redshift Loss | SNR threshold={snr_threshold:.1f}, d={degradation:.1f}', 
                     fontsize=18, fontweight='bold')
        if z_data:
            ax2.legend(fontsize=9, loc='upper right')

        plt.tight_layout()
        plt.show()

In [None]:
# Lay out the controls in one row above the plot area, then re-render the plot
# whenever any control changes. `interactive_output` wires the widgets to the
# callback WITHOUT displaying them again; plain `interact(...)` with these
# widget instances would render a second copy of every control. The callback
# draws into `output_deg_plot` itself, so the Output widget returned by
# `interactive_output` stays empty and is not displayed.
from IPython.display import display

controls_deg = HBox([tpl_dropdown_deg, spin_dropdown_deg, redshift_dropdown_deg,
                     snr_threshold_slider_deg, degradation_slider])
display(controls_deg, output_deg_plot)

# Create interactive plot (also renders it once with the initial values)
_ = widgets.interactive_output(
    plot_degradation_effect,
    {
        'tpl_val': tpl_dropdown_deg,
        'spin_a': spin_dropdown_deg,
        'z_val': redshift_dropdown_deg,
        'snr_threshold': snr_threshold_slider_deg,
        'degradation': degradation_slider,
    },
)