# In [None]:
import os
import re
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
from scipy import signal, stats  # Added stats for kurtosis/skewness
from scipy.ndimage import uniform_filter1d
try:
    import mplcursors
    HAVE_MPLCURSORS = True
except ImportError:
    HAVE_MPLCURSORS = False
    print("Note: Install 'mplcursors' for interactive hover annotations on plots.")

# ---------------------------
# User Input: File paths and options
# ---------------------------
def _prompt_float(prompt, default):
    """Ask for a number until the reply parses as a float.

    An empty reply selects *default*; a non-numeric reply is reported and the
    question is asked again instead of aborting the script with ValueError.
    """
    while True:
        reply = input(prompt).strip()
        if not reply:
            return default
        try:
            return float(reply)
        except ValueError:
            print(f"'{reply}' is not a number. Please try again.")


def _prompt_yes_no(prompt, default=True):
    """Ask a yes/no question and return a bool.

    With ``default=True`` only 'n'/'no' yields False; with ``default=False``
    only 'y'/'yes' yields True. Any other reply keeps the default.
    """
    reply = input(prompt).strip().lower()
    if default:
        return reply not in ('n', 'no')
    return reply in ('y', 'yes')


# The data file is mandatory: keep asking until something is entered.
data_file_path = input("Enter path to seismic data file (.mcs or .segy): ").strip()
while not data_file_path:
    data_file_path = input("Data file is required. Please enter the path to the data file: ").strip()

# Log and aux files are optional; an empty string means "not provided".
log_file_path = input("Enter path to log file (.log) [press Enter if none]: ").strip()
aux_file_path = input("Enter path to aux file (.aux) [press Enter if none]: ").strip()

# Station metadata (for labeling plots); user can override if desired
default_station = "KNR24"
default_network = "1B"
station_name = input(f"Enter station name [default: {default_station}]: ").strip() or default_station
network_code = input(f"Enter network code [default: {default_network}]: ").strip() or default_network

# Filtering and processing parameters
# Bandpass filter range for analysis. The corners are re-requested until they
# form a valid band, because a band-pass design needs 0 < low-cut < high-cut.
while True:
    lowcut_freq = _prompt_float("Enter low-cut frequency for bandpass filter [default 0.1 Hz]: ", 0.1)
    highcut_freq = _prompt_float("Enter high-cut frequency for bandpass filter [default 20 Hz]: ", 20.0)
    if 0 < lowcut_freq < highcut_freq:
        break
    print("Bandpass corners must satisfy 0 < low-cut < high-cut. Please re-enter both values.")

# Notch filter for powerline noise:
notch_freq = _prompt_float("Enter powerline frequency for notch filter (0 to skip) [default 50 Hz]: ", 50.0)
if notch_freq <= 0:
    notch_freq = None  # Use None to indicate no notch filtering

# Prompt for analysis sections (all default to "yes")
do_advanced = _prompt_yes_no("Perform advanced signal analysis (PSD, correlations, etc)? [Y/n]: ")
do_cleaning = _prompt_yes_no("Perform data cleaning and comparison analysis? [Y/n]: ")
do_inversion = _prompt_yes_no("Perform travel-time inversion analysis? [Y/n]: ")

# Phase selection is only asked when inversion is enabled; unknown replies fall back to 'both'.
phase_choice = "both"
if do_inversion:
    phase_choice = input("Pick phases for inversion: 'P' for P-waves only, 'S' for S-waves only, 'Both' for both [default: Both]: ").strip().lower() or "both"
    if phase_choice not in ('p', 's', 'both'):
        phase_choice = 'both'
pick_p = phase_choice in ('p', 'both')
pick_s = phase_choice in ('s', 'both')

# Option to save plots to files (defaults to "no")
save_plots = _prompt_yes_no("Save plots to files? [y/N]: ", default=False)
output_dir = "."  # always defined so later code can reference it safely
if save_plots:
    output_dir = input("Enter directory to save plots [default: current directory]: ").strip() or "."
    os.makedirs(output_dir, exist_ok=True)
    print(f"Plots will be saved in: {os.path.abspath(output_dir)}")

# ---------------------------
# Load and parse log/aux files if provided
# ---------------------------
def parse_log_file(log_file_path):
    """Extract recording metadata from a ``.log`` file.

    Each line is checked for one of the markers ``Start time:``,
    ``End time:``, ``Sample rate:``, ``File size:`` and ``Checksum:`` (in
    that priority order); only the first marker found on a line is handled.
    A value that cannot be converted is recorded in ``errors`` and leaves
    the corresponding field untouched.

    Returns a dict with keys ``start_time``/``end_time`` (datetime or None),
    ``sample_rate`` (float or None), ``file_size`` (int or None),
    ``checksum`` (str or None) and ``errors`` (list of str). A missing or
    empty path yields the all-None dict; a failure while reading the file
    is printed and whatever was parsed so far is returned.
    """
    def _timestamp(text):
        # "YYYY-MM-DD HH:MM:SS" -> ISO form accepted by fromisoformat
        return datetime.fromisoformat(text.strip().replace(' ', 'T'))

    def _first_token(text):
        # Drop a trailing unit such as "Hz" or "bytes"
        return text.strip().split(' ')[0]

    # (marker, result key, converter, label used in the error message)
    field_specs = (
        ('Start time:', 'start_time', _timestamp, 'start time'),
        ('End time:', 'end_time', _timestamp, 'end time'),
        ('Sample rate:', 'sample_rate', lambda text: float(_first_token(text)), 'sample rate'),
        ('File size:', 'file_size', lambda text: int(_first_token(text)), 'file size'),
        ('Checksum:', 'checksum', lambda text: text.strip(), 'checksum'),
    )

    log_data = {
        'start_time': None,
        'end_time': None,
        'sample_rate': None,
        'file_size': None,
        'checksum': None,
        'errors': []
    }
    if not log_file_path or not os.path.exists(log_file_path):
        return log_data
    try:
        with open(log_file_path, 'r') as f:
            raw_lines = f.readlines()
        for entry in map(str.strip, raw_lines):
            for marker, key, convert, label in field_specs:
                if marker not in entry:
                    continue
                try:
                    log_data[key] = convert(entry.split(marker)[1])
                except Exception:
                    log_data['errors'].append(f"Could not parse {label}: {entry}")
                break  # only the first matching marker on a line is handled
    except Exception as e:
        print(f"Error reading log file: {e}")
    return log_data

def parse_aux_file(aux_file_path):
    """Collect auxiliary readings from an ``.aux`` file.

    Lines are dispatched on their prefix:

    * ``GPS:`` / ``EVENT:`` - the second and third whitespace-separated
      tokens are joined as ``<date>T<time>`` and parsed as a datetime;
      lines with fewer than three tokens are skipped.
    * ``TEMP:`` / ``BATT:`` / ``DRIFT:`` - the first token after the prefix
      is parsed as a float.

    Lines that fail to parse are ignored silently (best effort). Returns a
    dict of lists keyed ``gps_times``, ``temperature_readings``,
    ``battery_voltage``, ``clock_drift`` and ``events``; a missing or empty
    path yields empty lists. A failure while reading the file is printed.
    """
    timestamp_fields = (('GPS:', 'gps_times'), ('EVENT:', 'events'))
    numeric_fields = (
        ('TEMP:', 'temperature_readings'),
        ('BATT:', 'battery_voltage'),
        ('DRIFT:', 'clock_drift'),
    )

    aux_data = {
        'gps_times': [],
        'temperature_readings': [],
        'battery_voltage': [],
        'clock_drift': [],
        'events': []
    }
    if not aux_file_path or not os.path.exists(aux_file_path):
        return aux_data
    try:
        with open(aux_file_path, 'r') as f:
            raw_lines = f.readlines()
        for raw in raw_lines:
            entry = raw.strip()
            try:
                # The prefixes are mutually exclusive, so at most one branch fires per line.
                for prefix, key in timestamp_fields:
                    if entry.startswith(prefix):
                        tokens = entry.split()
                        if len(tokens) >= 3:
                            stamp = tokens[1] + 'T' + tokens[2]
                            aux_data[key].append(datetime.fromisoformat(stamp))
                for prefix, key in numeric_fields:
                    if entry.startswith(prefix):
                        value_text = entry.split(prefix)[1].strip().split(' ')[0]
                        aux_data[key].append(float(value_text))
            except Exception:
                # malformed line: skip it and keep going
                pass
    except Exception as e:
        print(f"Error reading aux file: {e}")
    return aux_data

# Parse provided log/aux files
log_data = parse_log_file(log_file_path)
aux_data = parse_aux_file(aux_file_path)

# Echo only the metadata that was actually recovered (falsy values are skipped).
_metadata_report = []
if log_data['start_time']:
    _metadata_report.append(f"Log file start time: {log_data['start_time']}")
if log_data['end_time']:
    _metadata_report.append(f"Log file end time: {log_data['end_time']}")
if log_data['sample_rate']:
    _metadata_report.append(f"Log file sample rate: {log_data['sample_rate']} Hz")
if aux_data['gps_times']:
    _metadata_report.append(f"Aux file GPS entries: {len(aux_data['gps_times'])}")
if aux_data['events']:
    _metadata_report.append(f"Aux file events: {len(aux_data['events'])} detected events")
for _report_line in _metadata_report:
    print(_report_line)

# ---------------------------
# Read main seismic data file
# ---------------------------
# SEG-Y input is delegated to ObsPy; any other extension is decoded with the
# raw .mcs layout below (4096-byte header, 4 channels, 24-bit big-endian samples).
file_ext = os.path.splitext(data_file_path)[1].lower()
use_obspy = False
if file_ext in ['.segy', '.sgy']:
    try:
        from obspy import read
    except ImportError as exc:
        # Falling through to the .mcs decoder would interpret SEG-Y bytes as
        # 24-bit samples and silently produce garbage, so stop here instead.
        raise ImportError(
            "ObsPy not installed, cannot read SEGY directly. Please install obspy or convert the file."
        ) from exc
    use_obspy = True

if use_obspy:
    print(f"Reading SEGY file: {data_file_path} ...")
    st = read(data_file_path)
    # Merge to one trace per channel if needed
    st.merge(fill_value=0)
    # Each remaining trace is treated as one channel/component
    num_channels = len(st)
    sample_rate = st[0].stats.sampling_rate
    recording_start = st[0].stats.starttime.datetime  # Obspy UTCDateTime to datetime
    # NOTE(review): stacking assumes every trace has the same length after merging - confirm
    data = np.array([tr.data for tr in st]).T.astype(np.float64)  # shape (n_samples, n_channels)
    print(f"Loaded {num_channels} channels from SEGY. Sample rate = {sample_rate} Hz.")
else:
    # For .mcs or other binary format similar to the provided code
    print(f"Reading binary file: {data_file_path} ...")
    try:
        file_size = os.path.getsize(data_file_path)
    except FileNotFoundError:
        raise FileNotFoundError(f"Error: File '{data_file_path}' not found.")
    # Sample rate comes from the log when available, otherwise 100 Hz is assumed
    sample_rate = int(log_data['sample_rate'] or 100)
    # The .mcs format specifics (24-bit, 4 channels, 4096-byte header) as in original code
    num_channels = 4
    bits_per_sample = 24
    header_size_bytes = 4096
    bytes_per_sample = bits_per_sample // 8  # 3 bytes
    bytes_per_time_step = bytes_per_sample * num_channels

    total_data_bytes = file_size - header_size_bytes
    if total_data_bytes <= 0:
        raise ValueError("File is too small or header size is larger than file size.")
    with open(data_file_path, 'rb') as f:
        f.seek(header_size_bytes)
        raw_bytes = f.read(total_data_bytes)
    raw_data = np.frombuffer(raw_bytes, dtype=np.uint8)
    # Ensure we have whole time steps (one 3-byte sample for every channel)
    total_bytes = len(raw_data)
    complete_time_steps = total_bytes // bytes_per_time_step
    if complete_time_steps == 0:
        raise ValueError("No data samples found in file.")
    usable_bytes = complete_time_steps * bytes_per_time_step
    if usable_bytes < total_bytes:
        raw_data = raw_data[:usable_bytes]
        print(f"Warning: {total_bytes - usable_bytes} trailing bytes ignored (incomplete sample).")

    # Reshape raw data to [time_steps x bytes_per_time_step]
    raw_data_reshaped = raw_data.reshape(-1, bytes_per_time_step)
    # Prepare output array
    data = np.zeros((raw_data_reshaped.shape[0], num_channels), dtype=np.float64)
    # Gains default to 1 (no gain information is read from the log)
    hydro_gain = 1.0
    seis_gain = 1.0
    gain_factors = np.array([hydro_gain, seis_gain, seis_gain, seis_gain])

    # Convert 3-byte values to int and apply gains
    for ch in range(num_channels):
        # offset of this channel's 3 bytes inside each frame
        idx = ch * bytes_per_sample
        byte1 = raw_data_reshaped[:, idx]      # MSB
        byte2 = raw_data_reshaped[:, idx + 1]  # mid byte
        byte3 = raw_data_reshaped[:, idx + 2]  # LSB
        # combine bytes into one 24-bit value held in an int32
        samples = (byte1.astype(np.int32) << 16) | (byte2.astype(np.int32) << 8) | byte3.astype(np.int32)
        # two's complement adjustment for 24-bit: values above 0x7FFFFF are negative
        samples = np.where(samples > 0x7FFFFF, samples - 0x1000000, samples)
        # apply gain
        data[:, ch] = samples.astype(np.float64) / gain_factors[ch]

    recording_start = log_data['start_time']
    if recording_start is None:
        # Without a logged start time, absolute times (e.g. aux events) cannot be aligned.
        recording_start = datetime.now()
        print("Warning: no start time found in log file; using the current time as recording start.")
    print(f"Binary file read complete. Samples: {data.shape[0]}, Channels: {num_channels}, Sample rate: {sample_rate} Hz.")

# Channel naming based on assumptions or metadata (adjust if needed)
_four_channel_layout = (
    ('Hydrophone', 'Pressure [Pa]'),
    ('Seismometer Z', 'Amplitude'),
    ('Seismometer Y', 'Amplitude'),
    ('Seismometer X', 'Amplitude'),
)
if num_channels != 4:
    # For unknown channel count, label generically
    channel_names = [f"Channel {number}" for number in range(1, num_channels + 1)]
    channel_units = ['Amplitude' for _ in range(num_channels)]
else:
    channel_names = [name for name, _ in _four_channel_layout]
    channel_units = [unit for _, unit in _four_channel_layout]

# ---------------------------
# Signal amplitude detection and basic analysis
# ---------------------------
# Define helper to check if a signal chunk has significant amplitude (not just noise)
def has_amplitude(signal_data, noise_threshold=0.01, flat_threshold=1e-6):
    """Tell whether a chunk of samples holds more than noise.

    A chunk qualifies only if all three hold: its standard deviation
    exceeds ``noise_threshold``, its peak-to-peak range exceeds
    ``flat_threshold``, and its largest absolute value exceeds twice
    ``noise_threshold``. An empty chunk never qualifies.
    """
    if len(signal_data) == 0:
        return False
    spread = np.std(signal_data)
    peak_to_peak = np.ptp(signal_data)
    peak = np.max(np.abs(signal_data))
    above_noise = spread > noise_threshold
    not_flat = peak_to_peak > flat_threshold
    strong_peak = peak > noise_threshold * 2
    return (above_noise and not_flat and strong_peak)

# Find end of data where no significant amplitude is detected (assume after event)
def find_amplitude_end(data, sample_rate, chunk_size_seconds=10, noise_threshold=0.01):
    """Return the start index of the first chunk in which every channel is quiet.

    ``data`` is scanned in consecutive chunks of ``chunk_size_seconds``
    (the last chunk may be shorter). A chunk is quiet when ``has_amplitude``
    is false for each of its columns. If no quiet chunk exists the total
    number of samples is returned.
    """
    n_samples = data.shape[0]
    step = int(chunk_size_seconds * sample_rate)
    for chunk_start in range(0, n_samples, step):
        window = data[chunk_start:min(chunk_start + step, n_samples), :]
        for ch in range(data.shape[1]):
            if has_amplitude(window[:, ch], noise_threshold):
                break  # at least one channel is active: keep scanning
        else:
            # no channel had amplitude in this chunk, consider this the end
            return chunk_start
    return n_samples

# Determine region with signal (assuming after that is just noise)
amplitude_end_idx = find_amplitude_end(data, sample_rate)
if amplitude_end_idx >= data.shape[0]:
    # no quiet part detected, use all data
    data_with_amplitude = data
else:
    print(f"No significant signal after sample {amplitude_end_idx} (approx {amplitude_end_idx/sample_rate:.1f}s). Truncating data to signal region.")
    data_with_amplitude = data[:amplitude_end_idx, :]

# Length of the retained region, in samples and seconds, and its absolute end time
total_samples = data_with_amplitude.shape[0]
total_duration_sec = total_samples / sample_rate
recording_end_amplitude = recording_start + timedelta(seconds=total_duration_sec)
print(f"Using {total_samples} samples (~{total_duration_sec:.2f} seconds) containing the signal/event.")

# Compute amplitude envelope for each channel (to analyze signal strength over time)
def calculate_amplitude_envelope(signal_data, fs, smooth_window=1.0):
    """Return the smoothed amplitude envelope of a trace.

    The envelope is the magnitude of the analytic signal (Hilbert
    transform), optionally smoothed with a moving average.

    Parameters:
        signal_data: 1-D sequence of real samples.
        fs: sampling rate in Hz.
        smooth_window: moving-average length in seconds; smoothing is
            skipped when it covers one sample or less.

    Returns:
        Array with the same length as ``signal_data``. An empty input gives
        an empty array (``scipy.signal.hilbert`` would raise on it, which
        happens when the detected signal region has zero length).

    Relies on ``uniform_filter1d`` from ``scipy.ndimage`` being imported at
    file level.
    """
    samples = np.asarray(signal_data)
    if samples.size == 0:
        return np.zeros(samples.shape, dtype=np.float64)
    envelope = np.abs(signal.hilbert(samples))
    # Smooth envelope with moving average of specified window (seconds)
    window_pts = int(smooth_window * fs)
    if window_pts > 1:
        envelope = uniform_filter1d(envelope, size=window_pts)
    return envelope

# One envelope column per channel, smoothed over a 1 s window.
envelopes = np.zeros_like(data_with_amplitude)
for i in range(num_channels):
    # The keyword must be ``smooth_window`` (the helper has no ``window_size`` parameter).
    envelopes[:, i] = calculate_amplitude_envelope(data_with_amplitude[:, i], sample_rate, smooth_window=1.0)

# Plot overall signal strength over time for each channel (log scale).
# The first channel is drawn as the rectified raw trace; the others use their envelope.
plt.figure(figsize=(10, 6))
_strength_time = np.arange(total_samples)/sample_rate
for i in range(num_channels):
    if i == 0:
        _strength_trace = np.abs(data_with_amplitude[:, i])
        _strength_kind = "Pressure"
    else:
        _strength_trace = envelopes[:, i]
        _strength_kind = "Amplitude"
    plt.semilogy(_strength_time, _strength_trace,
                 label=f"{channel_names[i]} ({_strength_kind})", alpha=0.7)
_strength_ax = plt.gca()
_strength_ax.set_xlabel('Time [s]')
_strength_ax.set_ylabel('Signal Strength (log scale)')
_strength_ax.set_title(f'Signal Strength Over Time - Station {station_name}')
_strength_ax.legend()
_strength_ax.grid(True, which="both", ls="--", alpha=0.5)
# Save and show plot
if save_plots:
    fname = f"{station_name}_{network_code}_signal_strength.png"
    plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
    print(f"Saved plot: {fname}")
print("Displaying signal strength plot... Close the figure window to continue.")
if HAVE_MPLCURSORS:
    mplcursors.cursor(hover=True)
plt.show()

# ---------------------------
# Spectrogram for each channel (if signal present)
# ---------------------------
print("Generating spectrograms for each channel...")
n_spec_rows = (num_channels + 1) // 2
fig, axes = plt.subplots(n_spec_rows, 2, figsize=(12, 6 * n_spec_rows))
# With two columns subplots() always returns an array of Axes, even for a
# single channel, so it is flattened unconditionally (wrapping it in a list
# would hand an ndarray to the plotting calls below).
axes = np.atleast_1d(axes).ravel()
# Keep the window no longer than the trace and the overlap strictly below the
# window; for traces of 256 samples or more this is the usual 256/128 setup.
spec_nperseg = min(256, data_with_amplitude.shape[0])
spec_noverlap = spec_nperseg // 2
for i in range(num_channels):
    ax = axes[i]
    if has_amplitude(data_with_amplitude[:, i]):
        f, t, Sxx = signal.spectrogram(data_with_amplitude[:, i], fs=sample_rate, nperseg=spec_nperseg, noverlap=spec_noverlap, nfft=512, scaling='density')
        # small offset avoids log10(0) for empty bins
        Sxx_db = 10 * np.log10(Sxx + 1e-10)
        im = ax.pcolormesh(t, f, Sxx_db, shading='gouraud', cmap='viridis')
        # show at most 50 Hz, or up to Nyquist if that is lower
        ax.set_ylim(0, min(50, sample_rate/2))
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Frequency [Hz]')
        ax.set_title(f"{channel_names[i]} Spectrogram")
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('PSD [dB]')
    else:
        # channel is flat/noise-only: show a placeholder panel instead
        ax.text(0.5, 0.5, 'No signal', transform=ax.transAxes, ha='center', va='center', fontsize=12, color='gray')
        ax.set_title(f"{channel_names[i]} Spectrogram")
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Frequency [Hz]')
# Hide any unused subplot axes
for j in range(num_channels, len(axes)):
    axes[j].axis('off')
plt.suptitle(f"Spectrograms - Station {station_name}", fontsize=14)
plt.tight_layout()
# Save and show spectrograms
if save_plots:
    fname = f"{station_name}_{network_code}_spectrograms.png"
    plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
    print(f"Saved plot: {fname}")
print("Displaying spectrograms... Close the figure window to continue.")
# (Hover cursor not as useful on spectrogram heatmap, so we skip mplcursors here to avoid clutter)
plt.show()

# ---------------------------
# Optional: Plot segmented waveform (e.g., first N segments of data)
# ---------------------------
try:
    seg_count = int(input("Enter number of 5-second segments to plot from signal (0 to skip) [default 0]: ") or 0)
except ValueError:
    # A non-numeric reply means "skip"; other exceptions (e.g. Ctrl-C) now propagate.
    seg_count = 0
segment_duration = 5  # seconds per segment
samples_per_segment = int(segment_duration * sample_rate)
if seg_count > 0:
    # Never ask for more whole segments than the signal region contains
    max_segments = data_with_amplitude.shape[0] // samples_per_segment
    seg_count = min(seg_count, max_segments)
    for segment in range(seg_count):
        start_idx = segment * samples_per_segment
        end_idx = start_idx + samples_per_segment
        seg_data = data_with_amplitude[start_idx:end_idx, :]
        seg_time = np.arange(start_idx, end_idx) / sample_rate
        # squeeze=False keeps a 2-D Axes array even for a single channel, so
        # indexing by channel below always works.
        fig, axes = plt.subplots(num_channels, 1, figsize=(12, 8), sharex=True, squeeze=False)
        axes = axes[:, 0]
        for i in range(num_channels):
            ax = axes[i]
            if has_amplitude(seg_data[:, i]):
                ax.plot(seg_time, seg_data[:, i], color='C0', linewidth=0.8)
                # Mark events (from aux file) if they fall in this segment
                for ev_time in aux_data.get('events', []):
                    ev_sec = (ev_time - recording_start).total_seconds()
                    if ev_sec >= seg_time[0] and ev_sec <= seg_time[-1]:
                        ax.axvline(x=ev_sec, color='red', ls='--', lw=0.8)
                        ax.text(ev_sec, 0.9*ax.get_ylim()[1], 'EVENT', color='red', rotation=90, va='top', fontsize=8)
                ax.set_ylabel(channel_units[i])
            else:
                ax.text(0.5, 0.5, 'No signal', transform=ax.transAxes, ha='center', va='center', fontsize=10, color='gray')
            ax.set_title(f"{channel_names[i]}")
            ax.grid(True, ls='--', alpha=0.5)
        axes[-1].set_xlabel("Time [s]")
        plt.suptitle(f"{station_name} Segment {segment+1}: {segment*segment_duration}-{(segment+1)*segment_duration}s")
        plt.tight_layout()
        # Save and show
        if save_plots:
            fname = f"{station_name}_{network_code}_segment{segment+1}_{segment*segment_duration}-{(segment+1)*segment_duration}s.png"
            plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
            print(f"Saved plot: {fname}")
        print(f"Displaying segment {segment+1}/{seg_count}... Close the figure to continue.")
        if HAVE_MPLCURSORS:
            mplcursors.cursor(hover=True)
        plt.show()

# ---------------------------
# Advanced data analysis (optional)
# ---------------------------
# Runs only when the user answered anything other than 'n'/'no' at the
# advanced-analysis prompt. Everything in this section, including the helper
# functions, lives inside this branch and does not exist when it is skipped.
if do_advanced:
    print("\n=== Advanced Data Analysis ===")
    # Define advanced processing functions (bandpass, PSD, correlations, etc.)
    def apply_bandpass_filter(data, fs, lowcut=1.0, highcut=20.0, order=4):
        """Zero-phase Butterworth band-pass of a 1-D trace.

        The filter is designed as second-order sections and applied forward
        and backward (``sosfiltfilt``). The (b, a) polynomial form is
        numerically fragile for band-pass designs whose low corner is a tiny
        fraction of Nyquist (e.g. 0.1 Hz at 100 Hz sampling).

        Parameters:
            data: 1-D array of samples.
            fs: sampling rate in Hz.
            lowcut, highcut: pass-band corners in Hz.
            order: Butterworth order passed to the design routine.

        Returns:
            Filtered array with the same length as ``data``.

        Raises:
            ValueError: if the corners do not satisfy
                0 < lowcut < highcut < fs / 2.
        """
        nyquist = 0.5 * fs
        if not 0 < lowcut < highcut < nyquist:
            raise ValueError(
                f"Band-pass corners must satisfy 0 < lowcut < highcut < Nyquist "
                f"({nyquist} Hz); got lowcut={lowcut}, highcut={highcut}."
            )
        sos = signal.butter(order, [lowcut / nyquist, highcut / nyquist], btype='band', output='sos')
        return signal.sosfiltfilt(sos, data)
    def calculate_psd_vals(data, fs, nperseg=None):
        """Welch power spectral density of a trace.

        When ``nperseg`` is not given the segment length is half the trace
        length, capped at 256; if that would be shorter than 8 samples the
        whole trace is used as a single segment. Returns the
        ``(frequencies, psd)`` pair from ``scipy.signal.welch`` with density
        scaling.
        """
        segment_len = nperseg
        if segment_len is None:
            n_samples = len(data)
            segment_len = min(256, n_samples//2)
            if segment_len < 8:
                segment_len = n_samples
        freqs, density = signal.welch(data, fs=fs, nperseg=segment_len, scaling='density')
        return freqs, density
    def calculate_cross_corr(x, y, max_lag_samples=1000):
        """Full cross-correlation of ``x`` and ``y`` restricted to small lags.

        Returns ``(lags, corr)`` containing only the entries whose lag lies
        within ``+/- max_lag_samples``. The correlation is not normalized.
        """
        all_lags = signal.correlation_lags(len(x), len(y), mode='full')
        all_corr = signal.correlate(x, y, mode='full')
        keep = (all_lags >= -max_lag_samples) & (all_lags <= max_lag_samples)
        return all_lags[keep], all_corr[keep]
    def calculate_auto_corr(x, max_lag_samples=500):
        """Autocorrelation of ``x`` normalized by its maximum, limited to small lags.

        Returns ``(lags, corr)`` for lags within ``+/- max_lag_samples``.
        Normalization uses the maximum of the full correlation, taken before
        the lag window is applied.
        """
        n_samples = len(x)
        all_lags = signal.correlation_lags(n_samples, n_samples, mode='full')
        raw = signal.correlate(x, x, mode='full')
        scaled = raw / np.max(raw)  # normalize
        keep = (all_lags >= -max_lag_samples) & (all_lags <= max_lag_samples)
        return all_lags[keep], scaled[keep]
    def detect_events_sta_lta(data, fs, threshold=5.0, min_distance=1000, sta_seconds=1.0, lta_seconds=10.0):
        """Pick events from the ratio of short- to long-term average amplitude.

        Both averages are moving means of ``|data|`` computed with
        ``uniform_filter1d`` (imported at file level from ``scipy.ndimage``).

        Parameters:
            data: 1-D trace.
            fs: sampling rate in Hz.
            threshold: minimum STA/LTA ratio for a peak to count.
            min_distance: minimum spacing between picks, in samples.
            sta_seconds, lta_seconds: averaging window lengths in seconds
                (previously fixed at 1 s and 10 s).

        Returns:
            ``(peaks, ratio)``: sample indices of the picks and the STA/LTA
            ratio for every sample.
        """
        # Work in float: on an integer array the moving average would be truncated.
        rectified = np.abs(np.asarray(data, dtype=np.float64))
        # A window must span at least one sample for the filter to be valid.
        short_win = max(1, int(sta_seconds * fs))
        long_win = max(1, int(lta_seconds * fs))
        sta = uniform_filter1d(rectified, size=short_win)
        lta = uniform_filter1d(rectified, size=long_win)
        ratio = sta / (lta + 1e-10)  # offset guards against division by zero
        peaks, _ = signal.find_peaks(ratio, height=threshold, distance=min_distance)
        return peaks, ratio
    def analyze_signal_quality(data, fs):
        """Compute simple quality metrics for a trace.

        Returns a dict with:
            ``snr_db``: 10*log10 of total variance over the variance of the
                residual left after moving-average smoothing.
            ``dynamic_range_db``: 20*log10 of peak absolute amplitude over
                standard deviation.
            ``zero_crossing_rate``: sign changes per second.

        Raises:
            ValueError: if ``data`` is empty.

        Relies on ``uniform_filter1d`` from ``scipy.ndimage`` being imported
        at file level.
        """
        samples = np.asarray(data, dtype=np.float64)
        if samples.size == 0:
            raise ValueError("analyze_signal_quality requires at least one sample.")
        metrics = {}
        # SNR: ratio of total variance to variance of high-frequency residual as noise.
        # The smoothing length is clamped to >= 1 because traces shorter than
        # 10 samples would otherwise request a zero-length filter.
        smooth_pts = max(1, min(100, len(samples)//10))
        smooth = uniform_filter1d(samples, size=smooth_pts)
        noise = samples - smooth
        signal_power = np.var(samples)
        noise_power = np.var(noise)
        metrics['snr_db'] = 10 * np.log10(signal_power / (noise_power + 1e-10))
        # Dynamic range
        metrics['dynamic_range_db'] = 20 * np.log10(np.max(np.abs(samples)) / (np.std(samples) + 1e-10))
        # Zero-crossing rate
        zero_crossings = np.where(np.diff(np.sign(samples)))[0]
        metrics['zero_crossing_rate'] = len(zero_crossings) / (len(samples) / fs)
        return metrics
    def perform_polarization_analysis(z, ns, ew, fs, window_size=1000):
        """Sliding-window polarization attributes from three components.

        For each window (50 % overlap) the 3x3 covariance matrix of the
        components is formed and its eigenvalues l0 <= l1 <= l2 give

            rectilinearity = 1 - (l0 + l1) / (2 * l2)
            planarity      = 1 - 2 * l0 / (l1 + l2)

        Parameters:
            z, ns, ew: equally long 1-D component traces.
            fs: sampling rate in Hz (used for the window start times).
            window_size: window length in samples.

        Returns:
            List of dicts with keys ``time`` (window start in seconds),
            ``rectilinearity`` and ``planarity``; empty when the traces are
            shorter than one window.
        """
        results = []
        hop = max(1, window_size // 2)  # a zero step would make range() raise
        # "+ 1" includes the last full window, so a trace exactly one window
        # long yields one result instead of none.
        for i in range(0, len(z) - window_size + 1, hop):
            window = slice(i, i + window_size)
            cov = np.cov([z[window], ns[window], ew[window]])
            # eigh returns eigenvalues in ascending order; eigenvectors are not needed
            eigenvals, _ = np.linalg.eigh(cov)
            # compute polarization metrics (tiny offsets avoid division by zero)
            rect = 1 - (eigenvals[0] + eigenvals[1]) / (2*eigenvals[2] + 1e-10)
            plan = 1 - (2*eigenvals[0]) / (eigenvals[1] + eigenvals[2] + 1e-10)
            results.append({'time': i/fs, 'rectilinearity': rect, 'planarity': plan})
        return results

    # Apply bandpass filter (using user-provided lowcut/highcut or defaults)
    # Each channel column is filtered independently into a same-shaped array.
    filtered_data = np.zeros_like(data_with_amplitude)
    for i in range(num_channels):
        _raw_trace = data_with_amplitude[:, i]
        filtered_data[:, i] = apply_bandpass_filter(
            _raw_trace,
            sample_rate,
            lowcut=lowcut_freq,
            highcut=highcut_freq,
        )
    print(f"Applied bandpass filter ({lowcut_freq}-{highcut_freq} Hz) to all channels.")

    # Power Spectral Density for each channel (one curve per channel on a log y-axis)
    plt.figure(figsize=(10, 6))
    for i in range(num_channels):
        f, Pxx = calculate_psd_vals(filtered_data[:, i], sample_rate)
        plt.semilogy(f, Pxx, label=channel_names[i], alpha=0.8)
    _psd_ax = plt.gca()
    _psd_ax.set_title('Power Spectral Density (bandpass filtered data)')
    _psd_ax.set_xlabel('Frequency [Hz]')
    _psd_ax.set_ylabel('PSD [arb. units^2/Hz]')
    # show a little beyond the upper filter corner
    _psd_ax.set_xlim(0, highcut_freq*1.5)
    _psd_ax.grid(True, which="both", ls="--", alpha=0.5)
    _psd_ax.legend()
    if save_plots:
        fname = f"{station_name}_{network_code}_PSD.png"
        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
        print(f"Saved plot: {fname}")
    print("Displaying PSD plot... Close the figure to continue.")
    if HAVE_MPLCURSORS:
        mplcursors.cursor(hover=True)
    plt.show()

    # Cross-correlation between seismometer components (if 3 components exist)
    if num_channels >= 3:
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        axes = axes.flatten()
        pairs = [(2, 3), (1, 3), (1, 2)]  # Y-X, Z-X, Z-Y if using names above
        titles = ['Y-X Corr', 'Z-X Corr', 'Z-Y Corr']
        for idx, (ch1, ch2) in enumerate(pairs):
            # skip pairs that reference a channel index this recording does not have
            if ch1 >= num_channels or ch2 >= num_channels:
                continue
            lags, corr = calculate_cross_corr(filtered_data[:, ch1], filtered_data[:, ch2], max_lag_samples=500)
            # scale to unit peak; the offset avoids dividing by zero
            corr_norm = corr / (np.max(np.abs(corr)) + 1e-10)
            _panel = axes[idx]
            _panel.plot(lags/sample_rate, corr_norm, 'k-')
            _panel.set_title(titles[idx])
            _panel.set_xlabel('Lag [s]')
            _panel.set_ylabel('Correlation (norm)')
            _panel.grid(True, ls='--', alpha=0.5)
        # the fourth panel of the 2x2 grid is never used
        axes[-1].axis('off')
        plt.suptitle('Cross-correlation between components')
        plt.tight_layout()
        if save_plots:
            fname = f"{station_name}_{network_code}_cross_correlation.png"
            plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
            print(f"Saved plot: {fname}")
        print("Displaying cross-correlation plots... Close the figure to continue.")
        plt.show()

    # Auto-correlation for first up to 4 channels
    n_autocorr = min(num_channels, 4)
    # Two panels per row: ceil(n / 2) rows. (n // 2 + 1 produced an extra,
    # entirely blank row for 2 or 4 channels.) squeeze=False guarantees a
    # 2-D Axes array so flatten() is always valid.
    fig, axes = plt.subplots(max(1, (n_autocorr + 1) // 2), 2, figsize=(12, 8), squeeze=False)
    axes = axes.flatten()
    for i in range(n_autocorr):
        lags, acorr = calculate_auto_corr(filtered_data[:, i], max_lag_samples=200)
        axes[i].plot(lags/sample_rate, acorr, 'b-')
        axes[i].set_title(f'Autocorrelation - {channel_names[i]}')
        axes[i].set_xlabel('Lag [s]')
        axes[i].set_ylabel('Correlation (norm)')
        axes[i].grid(True, ls='--', alpha=0.5)
    # Hide unused axes
    for j in range(n_autocorr, len(axes)):
        axes[j].axis('off')
    plt.tight_layout()
    if save_plots:
        fname = f"{station_name}_{network_code}_autocorr.png"
        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
        print(f"Saved plot: {fname}")
    print("Displaying autocorrelation plots... Close figure to continue.")
    plt.show()

    # Event detection (STA/LTA) on each seismometer channel.
    # With three channels or fewer every channel is analysed; otherwise only
    # channels whose name starts with "seismometer".
    plt.figure(figsize=(10, 6))
    for i in range(num_channels):
        _is_seismometer = channel_names[i].lower().startswith('seismometer')
        if not (_is_seismometer or num_channels <= 3):
            continue
        peaks, ratio = detect_events_sta_lta(filtered_data[:, i], sample_rate, threshold=3.0, min_distance=int(1*sample_rate))
        _ratio_time = np.arange(len(ratio))/sample_rate
        plt.plot(_ratio_time, ratio, label=f"{channel_names[i]}")
        # red dots mark the picked events on the ratio curve
        plt.plot(np.array(peaks)/sample_rate, ratio[peaks], 'ro', markersize=4)
        print(f"{channel_names[i]}: {len(peaks)} event(s) detected by STA/LTA.")
    _sta_lta_ax = plt.gca()
    _sta_lta_ax.set_title('Event Detection (STA/LTA)')
    _sta_lta_ax.set_xlabel('Time [s]')
    _sta_lta_ax.set_ylabel('STA/LTA Ratio')
    _sta_lta_ax.set_yscale('log')
    _sta_lta_ax.legend()
    _sta_lta_ax.grid(True, ls='--', alpha=0.5)
    if save_plots:
        fname = f"{station_name}_{network_code}_STA_LTA.png"
        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
        print(f"Saved plot: {fname}")
    print("Displaying STA/LTA event detection plot... Close figure to continue.")
    if HAVE_MPLCURSORS:
        mplcursors.cursor(hover=True)
    plt.show()

    # Signal quality metrics: one summary line per channel
    for i in range(num_channels):
        metrics = analyze_signal_quality(filtered_data[:, i], sample_rate)
        _snr = metrics['snr_db']
        _dyn = metrics['dynamic_range_db']
        _zcr = metrics['zero_crossing_rate']
        print(f"{channel_names[i]} - SNR: {_snr:.2f} dB, Dynamic Range: {_dyn:.2f} dB, Zero-crossing rate: {_zcr:.2f} Hz")

    # Polarization analysis (if 3 components available).
    # Uses channels 1-3, i.e. the seismometer components of the 4-channel layout.
    if num_channels >= 4:
        results = perform_polarization_analysis(filtered_data[:, 1], filtered_data[:, 2], filtered_data[:, 3], sample_rate)
        times, rect_vals, plan_vals = [], [], []
        for _window_result in results:
            times.append(_window_result['time'])
            rect_vals.append(_window_result['rectilinearity'])
            plan_vals.append(_window_result['planarity'])
        fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
        # upper panel: rectilinearity
        axes[0].plot(times, rect_vals, 'b-')
        axes[0].set(ylabel='Rectilinearity', title='Polarization Analysis')
        axes[0].grid(True, ls='--', alpha=0.5)
        # lower panel: planarity (shares the time axis)
        axes[1].plot(times, plan_vals, 'r-')
        axes[1].set(xlabel='Time [s]', ylabel='Planarity')
        axes[1].grid(True, ls='--', alpha=0.5)
        plt.tight_layout()
        if save_plots:
            fname = f"{station_name}_{network_code}_polarization.png"
            plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
            print(f"Saved plot: {fname}")
        print("Displaying polarization analysis... Close figure to continue.")
        if HAVE_MPLCURSORS:
            mplcursors.cursor(hover=True)
        plt.show()

    # Detailed spectrogram with log frequency scale
    n_detail_rows = (num_channels + 1) // 2
    fig, axes = plt.subplots(n_detail_rows, 2, figsize=(12, 6 * n_detail_rows))
    axes = axes.flatten()
    freq_bands = {'ULF': (0.001, 0.01), 'VLF': (0.01, 0.1), 'LF': (0.1, 1), 'MF': (1, 10), 'HF': (10, 50)}
    # scipy.signal.spectrogram requires noverlap < nperseg <= nfft. All three
    # are derived from the trace length so that short recordings (fewer than
    # about 2570 samples) no longer fail on a fixed noverlap of 256. Long
    # recordings still get 512 / 256 / 1024.
    n_filtered = len(filtered_data)
    detail_nperseg = max(2, min(512, n_filtered // 10))
    detail_noverlap = min(256, detail_nperseg // 2)
    detail_nfft = max(detail_nperseg, min(1024, n_filtered // 5))
    # never extend the frequency axis beyond Nyquist
    detail_fmax = min(50, sample_rate / 2)
    for i in range(num_channels):
        ax = axes[i]
        f, t, Sxx = signal.spectrogram(filtered_data[:, i], fs=sample_rate, nperseg=detail_nperseg, noverlap=detail_noverlap, nfft=detail_nfft)
        # small offset avoids log10(0) for empty bins
        Sxx_db = 10 * np.log10(Sxx + 1e-10)
        im = ax.pcolormesh(t, f, Sxx_db, shading='gouraud', cmap='viridis')
        ax.set_yscale('log')
        ax.set_ylim(0.1, detail_fmax)
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Freq [Hz]')
        ax.set_title(f'{channel_names[i]} Spectrogram (log f)')
        # Mark frequency bands that lie completely inside the visible range
        for band, (f_low, f_high) in freq_bands.items():
            if f_low >= ax.get_ylim()[0] and f_high <= ax.get_ylim()[1]:
                ax.axhspan(f_low, f_high, color='gray', alpha=0.1)
                # label at the geometric centre of the band (log axis)
                ax.text(t.max()*0.95, np.sqrt(f_low*f_high), band, va='center', ha='right', fontsize=8, backgroundcolor='white')
        plt.colorbar(im, ax=ax, label='PSD [dB]')
    for j in range(num_channels, len(axes)):
        axes[j].axis('off')
    plt.suptitle('Detailed Spectrograms')
    plt.tight_layout()
    if save_plots:
        fname = f"{station_name}_{network_code}_detailed_spectrogram.png"
        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
        print(f"Saved plot: {fname}")
    print("Displaying detailed spectrograms... Close figure to continue.")
    plt.show()

    # Statistical analysis of signals: print descriptive statistics per channel
    for i in range(num_channels):
        x = filtered_data[:, i]
        _rms = np.sqrt(np.mean(x**2))
        stats_dict = {}
        stats_dict['mean'] = np.mean(x)
        stats_dict['std'] = np.std(x)
        stats_dict['max'] = np.max(x)
        stats_dict['min'] = np.min(x)
        stats_dict['rms'] = _rms
        # peak over RMS; the offset avoids dividing by zero on a flat trace
        stats_dict['crest_factor'] = np.max(np.abs(x)) / (_rms + 1e-10)
        # fisher=False reports Pearson kurtosis (3.0 for a normal distribution)
        stats_dict['kurtosis'] = stats.kurtosis(x, fisher=False)
        stats_dict['skewness'] = stats.skew(x)
        print(f"\n{channel_names[i]} stats:")
        for k, v in stats_dict.items():
            print(f"  {k}: {v:.4f}")
    # Comparative analysis (distribution & energy over time)
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    time_axis = np.arange(filtered_data.shape[0]) / sample_rate
    for i in range(num_channels):
        # left panel: amplitude histogram, normalized to a density
        axes[0].hist(filtered_data[:, i], bins=50, alpha=0.5, density=True, label=channel_names[i])
        # right panel: running sum of squared amplitude, scaled to end at 1
        energy = np.cumsum(filtered_data[:, i]**2)
        axes[1].plot(time_axis, energy/energy[-1], label=channel_names[i])
    axes[0].set(title='Amplitude Distribution', xlabel='Amplitude', ylabel='Density')
    axes[0].legend()
    axes[0].grid(True, ls='--', alpha=0.5)
    axes[1].set(title='Cumulative Energy', xlabel='Time [s]', ylabel='Normalized Cumulative Energy')
    axes[1].legend()
    axes[1].grid(True, ls='--', alpha=0.5)
    plt.tight_layout()
    if save_plots:
        fname = f"{station_name}_{network_code}_distribution_energy.png"
        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
        print(f"Saved plot: {fname}")
    print("Displaying distribution and energy plots... Close figure to continue.")
    plt.show()

# ---------------------------
# Data cleaning and comparison (optional)
# ---------------------------
# Announce whether the cleaning pipeline will run.
print(
    "\n=== Data Cleaning Pipeline ==="
    if do_cleaning
    else "\nSkipping data cleaning. Using raw data for phase picking/inversion."
)
# Define cleaning functions
def remove_instrument_response(data, fs, sensitivity=1.0):
    """Apply a simplified instrument correction to one trace.

    This is not a full deconvolution: the trace is passed through a
    zero-phase, 2nd-order Butterworth high-pass with a 0.01 Hz corner
    (removing DC and very low frequencies) and then multiplied by
    ``sensitivity``.

    Parameters
    ----------
    data : array_like
        Samples of a single channel.
    fs : float
        Sampling rate in Hz.
    sensitivity : float, optional
        Scalar gain applied after filtering.

    Returns
    -------
    numpy.ndarray
        Filtered and scaled trace, same length as ``data``.
    """
    corner_hz = 0.01
    half_rate = 0.5 * fs
    coeffs = signal.butter(2, corner_hz / half_rate, btype='high')
    highpassed = signal.filtfilt(coeffs[0], coeffs[1], data)
    return highpassed * sensitivity
def remove_trend(data):
    """Return ``data`` with its least-squares straight-line fit subtracted."""
    detrended = signal.detrend(data, type='linear')
    return detrended
def remove_mean(data):
    """Return ``data`` shifted so that its arithmetic mean is zero."""
    dc_level = np.mean(data)
    return data - dc_level
def apply_taper(data, taper_percent=0.05):
    """Fade both ends of a trace with a raised-cosine ramp.

    Parameters
    ----------
    data : numpy.ndarray
        One-dimensional trace.
    taper_percent : float, optional
        Fraction of the trace length used for the ramp at *each* end.

    Returns
    -------
    numpy.ndarray
        ``data`` multiplied by the taper window. When the ramp would be
        shorter than one sample the input object is returned unchanged.
    """
    n_samples = len(data)
    ramp_len = int(n_samples * taper_percent)
    if ramp_len < 1:
        return data
    # Rising half-cosine: 0 at the first sample, approaching 1 at the ramp end.
    ramp = 0.5 * (1 - np.cos(np.pi * np.arange(ramp_len) / ramp_len))
    window = np.ones(n_samples)
    window[:ramp_len] = ramp
    # Mirror the ramp for the trailing edge.
    window[-ramp_len:] = ramp[::-1]
    return data * window
def remove_outliers_iqr(data, multiplier=1.5):
    """Clamp samples lying outside the Tukey IQR fences.

    The fences are ``Q1 - multiplier*IQR`` and ``Q3 + multiplier*IQR``.
    Samples beyond a fence are replaced by the fence value, so the output
    has the same length as the input. The input array is not modified.
    """
    first_q, third_q = np.percentile(data, 25), np.percentile(data, 75)
    spread = third_q - first_q
    lower_fence = first_q - multiplier * spread
    upper_fence = third_q + multiplier * spread
    clamped = np.copy(data)
    too_low = clamped < lower_fence
    clamped[too_low] = lower_fence
    too_high = clamped > upper_fence
    clamped[too_high] = upper_fence
    return clamped
def remove_glitches(data, threshold=5.0, window_size=10):
    """Replace isolated spikes with a local median.

    A sample is treated as a glitch when it deviates from the global median
    by more than ``threshold`` global standard deviations. Each glitch is
    replaced by the median of the *original* samples within ``window_size``
    samples on either side (clipped at the array ends).

    Returns a new array; ``data`` is left untouched.
    """
    n_samples = len(data)
    limit = threshold * np.std(data)
    deviation = np.abs(data - np.median(data))
    repaired = np.copy(data)
    for idx in np.nonzero(deviation > limit)[0]:
        lo = max(0, idx - window_size)
        hi = min(n_samples, idx + window_size + 1)
        # The median is taken from the unmodified input so earlier
        # replacements do not influence later ones.
        repaired[idx] = np.median(data[lo:hi])
    return repaired
def apply_notch_filter(data, fs, notch_freq=50.0, quality_factor=30.0):
    """Suppress a narrow band around ``notch_freq`` with a zero-phase IIR notch.

    Parameters
    ----------
    data : array_like
        Samples of a single channel.
    fs : float
        Sampling rate in Hz.
    notch_freq : float or None, optional
        Centre frequency of the notch in Hz. ``None``, ``0``, negative values
        and values at or above the Nyquist frequency disable the filter.
    quality_factor : float, optional
        Q of the notch (centre frequency divided by -3 dB bandwidth).

    Returns
    -------
    numpy.ndarray
        The filtered trace, or ``data`` itself (unchanged) when the notch is
        disabled.
    """
    nyquist = 0.5 * fs
    # Validate the frequency up front so the design call below can only fail
    # for API reasons, instead of relying on a catch-all exception handler.
    if not notch_freq or notch_freq <= 0 or notch_freq >= nyquist:
        return data
    try:
        # Newer SciPy usage with fs
        b, a = signal.iirnotch(notch_freq, quality_factor, fs=fs)
    except TypeError:
        # Older SciPy releases have no `fs` keyword and expect the frequency
        # normalized to Nyquist instead.
        b, a = signal.iirnotch(notch_freq / nyquist, quality_factor)
    return signal.filtfilt(b, a, data)
def apply_bandpass_filter_clean(data, fs, lowcut=0.1, highcut=20.0, order=4):
    """Zero-phase Butterworth band-pass used by the cleaning pipeline.

    Parameters
    ----------
    data : array_like
        Samples of a single channel.
    fs : float
        Sampling rate in Hz.
    lowcut, highcut : float, optional
        Band edges in Hz. A non-positive ``lowcut`` means "no low cut" and
        results in a low-pass at ``highcut``.
    order : int, optional
        Butterworth order.

    Returns
    -------
    numpy.ndarray
        The filtered trace. ``data`` itself is returned unchanged when the
        requested band cannot be realised (an edge at or above Nyquist, a
        non-positive ``highcut``, or ``lowcut >= highcut``).
    """
    nyquist = 0.5 * fs
    if lowcut >= nyquist or highcut >= nyquist:
        return data
    # Reject empty or inverted bands here; otherwise butter() raises a
    # ValueError and aborts the whole cleaning run on a bad user entry.
    if highcut <= 0 or lowcut >= highcut:
        return data
    if lowcut <= 0:
        # A 0 Hz lower edge is not a valid band-pass corner; honour the
        # upper edge alone.
        b, a = signal.butter(order, highcut / nyquist, btype='low')
    else:
        b, a = signal.butter(order, [lowcut / nyquist, highcut / nyquist], btype='band')
    return signal.filtfilt(b, a, data)
def apply_adaptive_filter(data, noise_window=1000):
    """Subtract a running average from the trace.

    The "noise" estimate is a moving mean over ``noise_window`` samples;
    removing it leaves the faster-varying part of the signal.

    Parameters
    ----------
    data : array_like
        Samples of a single channel.
    noise_window : int, optional
        Length of the moving-average window in samples.

    Returns
    -------
    numpy.ndarray
        Floating-point trace with the running average removed.
    """
    # Imported locally so the function does not depend on a module-level
    # import of uniform_filter1d being present elsewhere in the script.
    from scipy.ndimage import uniform_filter1d

    # uniform_filter1d keeps the input dtype, so integer samples would yield a
    # truncated average; work in floating point.
    samples = np.asarray(data, dtype=float)
    noise_est = uniform_filter1d(samples, size=noise_window)
    return samples - noise_est

# Apply cleaning steps channel by channel
if do_cleaning:
    # Every cleaning stage yields floating-point samples, so allocate a float
    # buffer explicitly: np.zeros_like would inherit the input dtype, and an
    # integer input would silently truncate the cleaned values on assignment.
    cleaned_data = np.zeros(data_with_amplitude.shape, dtype=float)
    # One (channel name, list of step descriptions) entry per channel.
    cleaning_steps = []
    for i in range(num_channels):
        # Work on a float copy so the original array is never modified.
        ch_data = data_with_amplitude[:, i].astype(float)
        steps = []
        # Step 1: remove DC offset
        ch_data = remove_mean(ch_data)
        steps.append("Removed DC offset")
        # Step 2: remove linear trend
        ch_data = remove_trend(ch_data)
        steps.append("Removed linear trend")
        # Step 3: taper edges
        ch_data = apply_taper(ch_data)
        steps.append("Tapered edges")
        # Step 4: remove outliers
        ch_data = remove_outliers_iqr(ch_data)
        steps.append("Removed outliers (IQR)")
        # Step 5: remove glitches
        ch_data = remove_glitches(ch_data)
        steps.append("Removed short glitches")
        # Step 6: remove instrument response (simplified highpass)
        ch_data = remove_instrument_response(ch_data, sample_rate)
        steps.append("Removed instrument response")
        # Step 7: notch filter (power line)
        if notch_freq:
            ch_data = apply_notch_filter(ch_data, sample_rate, notch_freq=notch_freq)
            steps.append(f"Notch filter ({notch_freq}Hz)")
        else:
            steps.append("No notch filter")
        # Step 8: bandpass filter
        ch_data = apply_bandpass_filter_clean(ch_data, sample_rate, lowcut=lowcut_freq, highcut=highcut_freq)
        steps.append(f"Bandpass {lowcut_freq}-{highcut_freq}Hz")
        # Step 9: adaptive noise cancel
        ch_data = apply_adaptive_filter(ch_data, noise_window=1000)
        steps.append("Adaptive noise cancellation")
        cleaned_data[:, i] = ch_data
        cleaning_steps.append((channel_names[i], steps))
        print(f"{channel_names[i]} cleaned with steps: " + ", ".join(steps))
else:
    # If not cleaning, just set cleaned_data to original truncated data for downstream compatibility
    cleaned_data = data_with_amplitude

# Compare original vs cleaned (if cleaning was done)
if do_cleaning:
    # Time domain comparison
    # squeeze=False keeps `axes` two-dimensional even for a single channel, so
    # the axes[i, col] indexing below works for any num_channels >= 1 (with
    # the default squeeze, a 1x2 grid comes back as a 1-D array and
    # axes[0, 0] raises IndexError).
    fig, axes = plt.subplots(num_channels, 2, figsize=(12, 3*num_channels), squeeze=False)
    for i in range(num_channels):
        # Left column: original trace; right column: cleaned trace.
        axes[i, 0].plot(np.arange(data_with_amplitude.shape[0])/sample_rate, data_with_amplitude[:, i], 'gray', label='Original', linewidth=0.8)
        axes[i, 0].set_ylabel(channel_units[i])
        axes[i, 0].set_title(f"{channel_names[i]} - Original")
        axes[i, 0].grid(True, ls='--', alpha=0.5)
        axes[i, 1].plot(np.arange(cleaned_data.shape[0])/sample_rate, cleaned_data[:, i], 'b', label='Cleaned', linewidth=0.8)
        axes[i, 1].set_title(f"{channel_names[i]} - Cleaned")
        axes[i, 1].grid(True, ls='--', alpha=0.5)
        # Only the bottom row carries the shared time-axis label.
        if i == num_channels-1:
            axes[i, 0].set_xlabel('Time [s]')
            axes[i, 1].set_xlabel('Time [s]')
    plt.tight_layout()
    if save_plots:
        fname = f"{station_name}_{network_code}_time_compare.png"
        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
        print(f"Saved plot: {fname}")
    print("Displaying time domain comparison (original vs cleaned)... Close figure to continue.")
    if HAVE_MPLCURSORS:
        mplcursors.cursor(hover=True)
    plt.show()

    # Frequency domain (PSD) comparison
    plt.figure(figsize=(10, 6))
    for i in range(num_channels):
        f_orig, Pxx_orig = signal.welch(data_with_amplitude[:, i], fs=sample_rate, nperseg=256)
        f_clean, Pxx_clean = signal.welch(cleaned_data[:, i], fs=sample_rate, nperseg=256)
        # Original spectra are drawn in unlabeled gray; only the cleaned
        # spectra appear in the legend.
        plt.semilogy(f_orig, Pxx_orig, 'gray', alpha=0.5)
        plt.semilogy(f_clean, Pxx_clean, label=channel_names[i])
    plt.title('PSD Comparison (Original vs Cleaned)')
    plt.xlabel('Frequency [Hz]')
    plt.ylabel('Power Spectral Density')
    # Show the pass band plus some margin above the high-cut frequency.
    plt.xlim(0, highcut_freq*1.5)
    plt.grid(True, which='both', ls='--', alpha=0.5)
    plt.legend()
    if save_plots:
        fname = f"{station_name}_{network_code}_PSD_compare.png"
        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
        print(f"Saved plot: {fname}")
    print("Displaying PSD comparison... Close figure to continue.")
    plt.show()

    # Spectrogram comparison for one channel (e.g., vertical seismometer if exists)
    ch_to_show = 1 if num_channels > 1 else 0
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    f, t, Sxx_orig = signal.spectrogram(data_with_amplitude[:, ch_to_show], fs=sample_rate, nperseg=256, noverlap=128)
    f, t, Sxx_clean = signal.spectrogram(cleaned_data[:, ch_to_show], fs=sample_rate, nperseg=256, noverlap=128)
    # Power is shown in dB; the small constant avoids log10(0).
    axes[0].pcolormesh(t, f, 10*np.log10(Sxx_orig+1e-10), shading='gouraud', cmap='viridis')
    axes[0].set_title(f"Original {channel_names[ch_to_show]} Spectrogram")
    axes[0].set_ylim(0, min(50, sample_rate/2))
    axes[0].set_ylabel('Frequency [Hz]')
    axes[0].set_xlabel('Time [s]')
    axes[1].pcolormesh(t, f, 10*np.log10(Sxx_clean+1e-10), shading='gouraud', cmap='viridis')
    axes[1].set_title(f"Cleaned {channel_names[ch_to_show]} Spectrogram")
    axes[1].set_ylim(0, min(50, sample_rate/2))
    axes[1].set_xlabel('Time [s]')
    plt.tight_layout()
    if save_plots:
        fname = f"{station_name}_{network_code}_spectrogram_compare.png"
        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
        print(f"Saved plot: {fname}")
    print("Displaying spectrogram comparison... Close figure to continue.")
    plt.show()

    # Histograms comparison
    plt.figure(figsize=(8, 6))
    for i in range(num_channels):
        plt.hist(data_with_amplitude[:, i], bins=50, alpha=0.5, label=f"{channel_names[i]} Original", density=True)
        plt.hist(cleaned_data[:, i], bins=50, alpha=0.5, label=f"{channel_names[i]} Cleaned", density=True)
    plt.title('Amplitude Distribution (Original vs Cleaned)')
    plt.xlabel('Amplitude')
    plt.ylabel('Probability Density')
    plt.legend()
    plt.grid(True, ls='--', alpha=0.5)
    if save_plots:
        fname = f"{station_name}_{network_code}_hist_compare.png"
        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
        print(f"Saved plot: {fname}")
    print("Displaying amplitude distribution comparison... Close figure to continue.")
    plt.show()

# ---------------------------
# Phase picking (P and S waves) and inversion (optional)
# ---------------------------
# Enhanced automatic P and S phase picker (from vertical component of cleaned data)
from scipy.signal import hilbert, butter, filtfilt, find_peaks

def bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Zero-phase Butterworth band-pass that tolerates low sampling rates.

    Parameters
    ----------
    data : array_like
        Samples of a single channel.
    lowcut, highcut : float
        Band edges in Hz.
    fs : float
        Sampling rate in Hz.
    order : int, optional
        Butterworth order.

    Returns
    -------
    numpy.ndarray
        The filtered trace.

    Raises
    ------
    ValueError
        If ``lowcut`` is not strictly between 0 and the Nyquist frequency,
        because no part of the requested band is representable.

    Notes
    -----
    When ``highcut`` is at or above Nyquist the upper edge carries no
    information at this sampling rate, so a high-pass at ``lowcut`` is
    applied instead of letting the filter design fail.
    """
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    if not 0 < low < 1:
        raise ValueError(
            f"lowcut={lowcut} Hz must lie strictly between 0 and the "
            f"Nyquist frequency ({nyq} Hz)"
        )
    if high >= 1:
        # Upper corner is not representable: keep everything above lowcut.
        b, a = butter(order, low, btype='high')
    else:
        b, a = butter(order, [low, high], btype='band')
    return filtfilt(b, a, data)

def enhanced_auto_pick_phases(signal_data, time_axis, fs, min_amp=0.1, min_dist=15):
    """Pick candidate P and S arrivals from envelope peaks of one trace.

    The trace is band-passed twice (5-20 Hz for P, 0.5-5 Hz for S), each
    result is turned into a Hilbert envelope normalized by its maximum, and
    peaks of those envelopes become picks.

    Parameters
    ----------
    signal_data : array_like
        Samples of the trace to pick on.
    time_axis : array_like
        Time of every sample, same length as ``signal_data``.
    fs : float
        Sampling rate in Hz.
    min_amp : float, optional
        Minimum normalized envelope height for an S peak; P peaks need
        ``1.2 * min_amp``.
    min_dist : int, optional
        Minimum spacing between P peaks in samples; S peaks use twice that.

    Returns
    -------
    list of dict
        Picks sorted by time, each with keys ``'time'``, ``'amplitude'``,
        ``'type'`` (``'P'`` or ``'S'``) and ``'index'``.
    """
    # High band carries the P energy, low band the S energy.
    p_band = bandpass_filter(signal_data, 5, 20, fs)
    s_band = bandpass_filter(signal_data, 0.5, 5, fs)

    def _unit_envelope(trace):
        # Hilbert envelope scaled so its maximum is (almost exactly) 1.
        envelope = np.abs(hilbert(trace))
        return envelope / (np.max(envelope) + 1e-9)

    p_env = _unit_envelope(p_band)
    s_env = _unit_envelope(s_band)
    p_candidates, _ = find_peaks(p_env, height=min_amp*1.2, distance=min_dist)
    s_candidates, _ = find_peaks(s_env, height=min_amp, distance=min_dist*2)

    record_end = time_axis[-1]
    # P picks are accepted only in the first 70% of the record.
    picks = [
        {'time': time_axis[k], 'amplitude': p_env[k], 'type': 'P', 'index': k}
        for k in p_candidates
        if time_axis[k] < record_end * 0.7
    ]
    p_arrivals = [pick['time'] for pick in picks]

    for k in s_candidates:
        t_s = time_axis[k]
        # S picks must fall after the first 20% of the record ...
        if not t_s > record_end * 0.2:
            continue
        # ... and more than 2 s away from every P pick, so the same arrival
        # is not reported twice.
        if all(abs(t_s - t_p) > 2.0 for t_p in p_arrivals):
            picks.append({'time': t_s, 'amplitude': s_env[k], 'type': 'S', 'index': k})

    picks.sort(key=lambda pick: pick['time'])
    return picks

# Picking runs whenever the user asked for P picks, S picks, or an inversion;
# it is skipped only when all three options are off.
if do_inversion or pick_p or pick_s:
    # Use vertical seismometer (channel 1 in our naming) if available, otherwise first channel
    vertical_index = 0 if not num_channels > 1 else 1
    # Picks are made on the cleaned trace of that channel.
    seismogram = cleaned_data[:, vertical_index]
    time_axis = np.arange(len(seismogram)) / sample_rate
    print("Picking seismic phases (P and S) from the data...")
    phases = enhanced_auto_pick_phases(seismogram, time_axis, sample_rate, min_amp=0.15, min_dist=25)
    # Keep only the phase types the user asked for.
    phases = [
        ph for ph in phases
        if (pick_p or ph['type'] != 'P') and (pick_s or ph['type'] != 'S')
    ]
    p_times = [ph['time'] for ph in phases if ph['type'] == 'P']
    s_times = [ph['time'] for ph in phases if ph['type'] == 'S']
    print(f"Detected {len(p_times)} P-wave picks and {len(s_times)} S-wave picks.")

    # Seismogram with one dashed vertical marker per pick (red = P, green = otherwise).
    plt.figure(figsize=(12, 4))
    plt.plot(time_axis, seismogram, 'k-', label='Seismogram')
    for ph in phases:
        color = 'r' if ph['type'] == 'P' else 'g'
        plt.axvline(ph['time'], color=color, linestyle='--', linewidth=1)
        plt.text(
            ph['time'], np.max(seismogram)*0.8, ph['type'],
            color=color, fontsize=10, rotation=90, va='bottom',
        )
    plt.title('Seismogram with Picked Phases')
    plt.xlabel('Time [s]')
    plt.ylabel('Amplitude')
    plt.legend()
    plt.grid(True, ls='--', alpha=0.5)
    if save_plots:
        fname = f"{station_name}_{network_code}_seismogram_picks.png"
        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
        print(f"Saved plot: {fname}")
    print("Displaying seismogram with phase picks... Close figure to continue.")
    if HAVE_MPLCURSORS:
        mplcursors.cursor(hover=True)
    plt.show()

    # Without an inversion request, picking is the last processing stage.
    if not do_inversion:
        print("Phase picking complete. Inversion was not requested. Processing finished.")
else:
    print("Skipping phase picking and inversion as per user choice.")
# ---------------------------
# Travel-time inversion for layered structure (optional)
# ---------------------------
if do_inversion:
    # Define travel-time calculation for given velocity model (P or S)
    def calculate_travel_times(model_velocities, model_depths, offsets, wave_type='P'):
        """Compute one travel time per offset through a stack of flat layers.

        Parameters
        ----------
        model_velocities : numpy.ndarray
            One velocity per layer; must support ``.copy()``. Presumably
            P-wave velocities in km/s -- confirm against the caller.
        model_depths : sequence of float
            Depth of each layer's lower interface. Layers without an entry
            here are given a fixed thickness of 50.0.
        offsets : sequence of float
            Horizontal distances; one time is returned for each.
        wave_type : str, optional
            ``'S'`` (case-insensitive) divides every velocity by 1.73;
            anything else uses the velocities as given.

        Returns
        -------
        numpy.ndarray
            Array of ``len(offsets)`` travel times.

        NOTE(review): ``angle`` starts at 0.0 for every offset and is only ever
        updated to ``arcsin(ratio * sin(angle))``, which stays 0 for finite
        velocity ratios. As written, ``horiz`` is therefore always 0.0, the
        early exit fires only for offsets <= 1e-6, and every larger offset
        gets the same vertical time (sum of thickness / velocity). A
        take-off angle that depends on the offset appears to be missing --
        confirm the intended ray geometry before relying on these times.
        """
        times = np.zeros(len(offsets))
        # Use Vp/Vs ratio ~1.73 if S-wave
        vel = model_velocities.copy()
        if wave_type.upper() == 'S':
            vel = [v/1.73 for v in vel]
        for j, x in enumerate(offsets):
            total_time = 0.0
            depth_top = 0.0  # depth of the top of the current layer
            angle = 0.0  # ray angle from vertical; see NOTE(review) in the docstring
            for layer in range(len(model_velocities)):
                # layer thickness
                if layer < len(model_depths):
                    thickness = model_depths[layer] - depth_top
                else:
                    thickness = 50.0  # extend last layer if needed
                # horizontal distance that can be traveled in this layer
                horiz = thickness * np.tan(angle) if angle != 0 else 0.0
                if horiz + 1e-6 >= x:  # if the ray exits in this layer
                    # remaining distance portion
                    if angle == 0:
                        travel = x / (vel[layer] + 1e-9)
                    else:
                        travel = (x - horiz) / (vel[layer] * np.cos(angle) + 1e-9)
                    total_time += travel
                    break
                # full layer traverse
                # (the 1e-9 terms guard against division by zero)
                travel = thickness / (vel[layer] * np.cos(angle) + 1e-9) if angle != 0 else thickness / (vel[layer] + 1e-9)
                total_time += travel
                x -= horiz  # reduce remaining horizontal distance
                depth_top = model_depths[layer] if layer < len(model_depths) else depth_top + thickness
                # Recalculate angle for next layer using Snell's law
                # NOTE(review): Snell's law is usually written
                # sin(a_next) = (v_next / v_cur) * sin(a_cur); the ratio here is
                # the other way round -- confirm.
                if layer < len(model_velocities) - 1:
                    try:
                        angle = np.arcsin((vel[layer] / vel[layer+1]) * np.sin(angle))
                    except ValueError:
                        # NOTE(review): np.arcsin normally returns nan (with a
                        # RuntimeWarning) for arguments outside [-1, 1] rather
                        # than raising, so this handler may never run -- confirm.
                        angle = np.pi/2  # critical angle exceeded, horizontal propagation
            times[j] = total_time
        return times

    # Joint inversion combining P and S travel times
    def joint_inversion(p_obs, s_obs, offsets, init_vel, init_depth, max_iter=20, damping=0.2, tol=1e-6):
        """Damped least-squares inversion of P (and optionally S) travel times.

        Each iteration predicts travel times with ``calculate_travel_times``,
        builds a forward finite-difference Jacobian with respect to every
        layer velocity and interface depth, and applies a damped
        Gauss-Newton update. S residuals are weighted 1.5x relative to P.

        Parameters
        ----------
        p_obs : array_like
            Observed P times, one per entry of ``offsets``.
        s_obs : array_like
            Observed S times for the first ``len(s_obs)`` offsets; may be
            empty for a P-only inversion.
        offsets : numpy.ndarray
            Offsets passed through to ``calculate_travel_times``.
        init_vel, init_depth : array_like
            Starting velocities and interface depths. They are copied and
            converted to float, so lists and integer arrays are accepted and
            the caller's arrays are never modified.
        max_iter : int, optional
            Maximum number of model updates.
        damping : float, optional
            Value added to the diagonal of the normal equations.
        tol : float, optional
            Weighted RMS misfit below which iteration stops.

        Returns
        -------
        vel : numpy.ndarray
            Final velocities, limited to the range [1.0, 8.5].
        dep : numpy.ndarray
            Final depths, at least 0.1 each and sorted ascending.
        history : dict
            ``history['misfit']`` holds the weighted RMS misfit before each
            update; its last entry always belongs to the returned model.
        """
        # Float copies: in-place style updates on an integer model would
        # truncate, and the caller's arrays must stay untouched.
        vel = np.array(init_vel, dtype=float)
        dep = np.array(init_depth, dtype=float)
        history = {'misfit': []}

        n_p = len(p_obs)
        n_s = len(s_obs)
        s_offsets = offsets[:n_s]
        # Build combined observed vector
        obs = np.concatenate([p_obs, s_obs])
        # Weighted least squares (weight S residuals higher if present)
        weights = np.ones(len(obs))
        weights[n_p:] = 1.5

        def _predict(v, d):
            # Stacked P then S predictions for a trial model.
            p_pred = calculate_travel_times(v, d, offsets, 'P')
            s_pred = calculate_travel_times(v, d, s_offsets, 'S') if n_s > 0 else np.array([])
            return np.concatenate([p_pred, s_pred])

        def _weighted_rms(residual):
            return np.sqrt(np.mean((residual * weights)**2))

        n_layers = len(vel)
        n_depths = len(dep)
        # Size the parameter vector from the model actually supplied instead
        # of assuming exactly n_layers - 1 interfaces.
        n_params = n_layers + n_depths
        delta = 0.01  # finite-difference step for both velocity and depth

        for it in range(max_iter):
            pred = _predict(vel, dep)
            res = obs - pred
            misfit = _weighted_rms(res)
            history['misfit'].append(misfit)
            if misfit < tol:
                print(f"Converged in {it} iterations, misfit={misfit:.6f}")
                break
            # Simple Jacobian: finite differences approximation for each layer velocity and depth
            J = np.zeros((len(pred), n_params))
            for j in range(n_layers):
                vel_temp = vel.copy()
                vel_temp[j] += delta
                J[:, j] = (_predict(vel_temp, dep) - pred) / delta
            for j in range(n_depths):
                dep_temp = dep.copy()
                dep_temp[j] += delta
                J[:, n_layers + j] = (_predict(vel, dep_temp) - pred) / delta
            # Damped least squares solution
            JTJ = J.T @ (J * weights[:, None])  # apply weights to J rows
            JTr = J.T @ (res * weights)
            # Solve (JTJ + damping*I) * update = J^T * residuals
            lhs = JTJ + damping*np.eye(n_params)
            try:
                update = np.linalg.solve(lhs, JTr)
            except np.linalg.LinAlgError:
                update = np.linalg.lstsq(lhs, JTr, rcond=None)[0]
            # Apply update (constrain to reasonable ranges)
            vel = np.clip(vel + update[:n_layers], 1.0, 8.5)
            # Depths stay positive and in ascending order.
            dep = np.sort(np.maximum(dep + update[n_layers:], 0.1))
        else:
            # The loop ended without converging, so the last recorded misfit
            # predates the final update; record the misfit of the model that
            # is actually returned.
            history['misfit'].append(_weighted_rms(obs - _predict(vel, dep)))
        return vel, dep, history

    # Assemble the observations, run the inversion and plot the fit.
    if not p_times:
        print("No P-wave picks found; inversion cannot proceed.")
        do_inversion = False
    else:
        # Real source-receiver distances are not available, so evenly spaced
        # synthetic offsets (in km) stand in for them.
        n_offsets = max(len(p_times), len(s_times), 8)  # ensure enough offsets
        offsets = np.linspace(5, 5*n_offsets, n_offsets)
        observed_p_times = np.array(p_times[:len(offsets)])
        if pick_s:
            observed_s_times = np.array(s_times[:len(offsets)])
        else:
            observed_s_times = np.array([])
        # One offset per P pick; S picks can never outnumber the offsets.
        offsets = offsets[:len(observed_p_times)]
        observed_s_times = observed_s_times[:len(offsets)]

        # Starting model: a simple layered crust.
        initial_vel = np.array([3.0, 5.5, 6.5, 7.5])  # km/s per layer
        initial_depth = np.array([5.0, 15.0, 30.0])   # depth of interfaces (km)
        print("Starting inversion with initial model:")
        print(f"  Velocities = {initial_vel} km/s")
        print(f"  Depths = {initial_depth} km")
        if len(observed_s_times) < 2:
            # Fewer than two S picks: drop them and invert P times only.
            print("Not enough S-wave picks, performing P-wave only inversion.")
            observed_s_times = np.array([])

        # Joint P+S inversion when S picks remain, otherwise P-only.
        final_vel, final_depth, inv_history = joint_inversion(
            observed_p_times,
            observed_s_times,
            offsets,
            initial_vel,
            initial_depth,
            max_iter=25,
            damping=0.2,
            tol=1e-4,
        )
        print("\n=== Inversion Results ===")
        print(f"Layer velocities (km/s): {np.round(final_vel, 2)}")
        print(f"Layer depths (km): {np.round(final_depth, 2)}")
        if len(inv_history['misfit']) > 0:
            print(f"Final misfit: {inv_history['misfit'][-1]:.6f} s")

        # Observed picks (markers) against times predicted by the final model
        # (dashed lines).
        have_s = len(observed_s_times) > 0
        calc_p_times = calculate_travel_times(final_vel, final_depth, offsets, 'P')
        if have_s:
            calc_s_times = calculate_travel_times(final_vel, final_depth, offsets[:len(observed_s_times)], 'S')
        else:
            calc_s_times = None
        fig, ax = plt.subplots(figsize=(6,4))
        ax.plot(offsets, observed_p_times, 'ro', label='Observed P')
        ax.plot(offsets, calc_p_times, 'r--', label='Calc P')
        if have_s:
            off_s = offsets[:len(observed_s_times)]
            ax.plot(off_s, observed_s_times, 'go', label='Observed S')
            ax.plot(off_s, calc_s_times, 'g--', label='Calc S')
        ax.set_xlabel('Offset [km]')
        ax.set_ylabel('Travel Time [s]')
        ax.set_title('Travel Time Fit')
        ax.legend()
        ax.grid(True, ls='--', alpha=0.5)
        if save_plots:
            fname = f"{station_name}_{network_code}_travel_time_fit.png"
            plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
            print(f"Saved plot: {fname}")
        print("Displaying travel-time fit plot... Close figure to finish.")
        if HAVE_MPLCURSORS:
            mplcursors.cursor(hover=True)
        plt.show()
        print("Inversion complete. Processing finished.")

Note: Install 'mplcursors' for interactive hover annotations on plots.
