# In [None]:
import cv2  # OpenCV for resizing
import numpy as np
import tensorflow as tf
from scipy.signal import spectrogram
from obspy import read  # Assuming we are using Obspy to read mseed files
import matplotlib.pyplot as plt

# Function to generate a log-scaled spectrogram and resize it to a fixed image size
def generate_spectrogram(signal, fs, window_size=5, overlap=0.8, target_size=(224, 224)):
    """
    Generate a log-scaled spectrogram and resize it.

    :param signal: 1D array of the seismic signal.
    :param fs: Sampling frequency of ``signal`` (samples per second).
    :param window_size: Segment length in seconds; converted to
        ``nperseg = int(fs * window_size)`` samples.
    :param overlap: Overlap between consecutive segments in SECONDS (not a
        fraction of ``window_size``); converted to
        ``noverlap = int(fs * overlap)`` samples. With the defaults, 5 s
        segments overlap by 0.8 s. ``noverlap`` must be smaller than
        ``nperseg`` or scipy raises ``ValueError``.
    :param target_size: ``(width, height)`` handed to ``cv2.resize``; width is
        the number of time columns, height the number of frequency rows.
    :return: Frequency array, time array, resized log spectrogram, and the
        log-scaled spectrogram before resizing.
    """
    nperseg = int(fs * window_size)
    # NOTE(review): `overlap` is interpreted in seconds, not as a ratio of the
    # window. The ratio form would be int(nperseg * overlap). It is kept as
    # seconds because the saved model was presumably trained on spectrograms
    # built this way -- TODO confirm against the training code.
    noverlap = int(fs * overlap)

    f, t, Sxx = spectrogram(signal, fs, nperseg=nperseg, noverlap=noverlap)
    # log1p compresses the dynamic range; the result is both plotted and fed to the model.
    Sxx = np.log1p(Sxx)

    # cv2.resize takes dsize as (width, height) = (time columns, frequency rows).
    Sxx_resized = cv2.resize(Sxx, target_size, interpolation=cv2.INTER_AREA)

    return f, t, Sxx_resized, Sxx

# Load the trained P-wave CNN once at module level; the inference code below reuses it.
_MODEL_PATH = 'p_wave_cnn_model.h5'
model = tf.keras.models.load_model(_MODEL_PATH)

# Function to map a predicted pixel column of the resized spectrogram back to the original time axis
def map_pixel_to_time(pred_pixel, original_t, target_size=224):
    """
    Map a column index in the resized spectrogram back to a time value.

    The resized image has ``target_size`` columns covering the same span as
    ``original_t``, so column ``p`` corresponds to original column
    ``int(p * len(original_t) / target_size)``.

    :param pred_pixel: Predicted column index in the resized spectrogram. It
        may be a non-integer (or out-of-range) model output; the scaled value
        is truncated toward zero and clamped to the valid index range.
    :param original_t: Time axis of the un-resized spectrogram (the ``t``
        returned by scipy's ``spectrogram``).
    :param target_size: Number of time columns in the resized spectrogram.
    :return: The entry of ``original_t`` at the mapped, clamped column.
    :raises ValueError: If ``original_t`` is empty or ``target_size`` is not positive.
    """
    original_size = len(original_t)
    if original_size == 0:
        raise ValueError("original_t must not be empty")
    if target_size <= 0:
        raise ValueError("target_size must be positive")

    scale_factor = original_size / target_size  # Scale from resized back to original
    original_pixel = int(pred_pixel * scale_factor)

    # Clamp into [0, original_size - 1]. A regression output can fall outside
    # the image (including negative values); without the lower bound a
    # negative index would silently wrap around to the END of the time axis.
    original_pixel = min(max(original_pixel, 0), original_size - 1)

    return original_t[original_pixel]

# Inference: estimate the P-wave onset time of a .mseed recording
def predict_p_wave_start_time(mseed_file, model):
    """
    Estimate the P-wave start time of a .mseed recording with the trained CNN.

    Only the first trace of the stream is used. Its resized log spectrogram
    (from ``generate_spectrogram``) is fed to ``model`` as a batch of one
    single-channel image. The model's first output is treated as a column
    index in that resized image and mapped back to the spectrogram time axis.

    :param mseed_file: Path to the .mseed file.
    :param model: Loaded model whose ``predict(...)[0][0]`` is the P-wave pixel column.
    :return: Predicted P-wave start time in seconds.
    """
    trace = read(mseed_file)[0]  # first trace only; any further traces are ignored
    _freqs, times, spec_image, _full_spec = generate_spectrogram(
        trace.data, trace.stats.sampling_rate
    )

    # (H, W) -> (1, H, W, 1): batch axis in front, channel axis at the end.
    batch = spec_image[np.newaxis, ..., np.newaxis]
    predicted_column = model.predict(batch)[0][0]

    start_time = map_pixel_to_time(predicted_column, times, target_size=224)
    print(f"Predicted P-wave start time: {start_time:.2f} seconds")
    return start_time

# Visualize the predicted P-wave start on the resized spectrogram
def plot_predicted_p_wave(mseed_file, pred_pixel, pred_time, model):
    """
    Show the resized spectrogram with a marker at the predicted P-wave column.

    The file is read again and its spectrogram regenerated, so the image
    matches what ``predict_p_wave_start_time`` fed to the model.

    :param mseed_file: Path to the .mseed file.
    :param pred_pixel: Predicted column index in the resized spectrogram.
    :param pred_time: Predicted P-wave start time in seconds (shown in the title).
    :param model: Unused here; kept for call-site compatibility.
    """
    trace = read(mseed_file)[0]
    freqs, _times, spec_image, _full_spec = generate_spectrogram(
        trace.data, trace.stats.sampling_rate
    )

    n_pixels = 224
    x_pixels = np.arange(n_pixels)
    y_freqs = np.linspace(freqs[0], freqs[-1], n_pixels)

    fig, ax = plt.subplots(figsize=(10, 6))
    mesh = ax.pcolormesh(x_pixels, y_freqs, spec_image, shading='gouraud')

    # Dashed vertical line at the predicted column.
    ax.axvline(
        pred_pixel,
        color='g',
        linestyle='--',
        label=f'Predicted P-wave pixel: {int(pred_pixel)}',
    )

    fig.colorbar(mesh, ax=ax, label='Log scaled amplitude')
    ax.set_xlabel('Time (pixels)')
    ax.set_ylabel('Frequency (Hz)')
    ax.legend()
    ax.set_title(f'Predicted P-wave Start Time: {pred_time:.2f} seconds')
    plt.show()

# Example inference
mseed_file = 'path_to_your_mseed_file.mseed'
pred_time = predict_p_wave_start_time(mseed_file, model)

# Visualize the prediction.
# predict_p_wave_start_time returns only the time, so rebuild the resized
# spectrogram here to get the predicted pixel for plotting.
_example_trace = read(mseed_file)[0]
_, _, Sxx_resized, _ = generate_spectrogram(
    _example_trace.data, _example_trace.stats.sampling_rate
)
pred_pixel = model.predict(np.expand_dims(Sxx_resized, axis=(0, -1)))[0][0]  # Get predicted pixel
plot_predicted_p_wave(mseed_file, pred_pixel, pred_time, model)
