# Applying the U-Net to real seismic data (detections the hard way)

**Author:** Amanda M. Thomas  
**Goal:** In this notebook we will once again build the U-Net but this time instead of training we simply load in the model weights we saved after training.  We then apply the trained model to data from NCEDC stations that recorded the 2022 Ferndale earthquake.


In [62]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchinfo import summary

## 1. Build the U-net

In [63]:
# Cell 2: U-Net Building Blocks
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv(x)

class UNet1D(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, features=[16, 32, 64, 128]):
        super().__init__()
        
        self.downs = nn.ModuleList()  # Encoder blocks (downsampling path)
        self.ups = nn.ModuleList()    # Decoder blocks (upsampling path)
    
        # ----- Encoder: Downsampling Path -----
        # Each ConvBlock halves the temporal resolution via pooling (done in forward),
        # and increases the number of feature channels.
        for feat in features:
            self.downs.append(ConvBlock(in_channels, feat))  # ConvBlock: Conv + ReLU + Conv + ReLU
            in_channels = feat  # Update in_channels for the next block
    
        # ----- Bottleneck -----
        # Deepest layer in the U-Net, connects encoder and decoder
        self.bottleneck = ConvBlock(features[-1], features[-1]*2)
    
        # ----- Decoder: Upsampling Path -----
        # Reverse features list for symmetrical decoder
        rev_feats = features[::-1]
        for feat in rev_feats:
            # First upsample (via transposed convolution)
            self.ups.append(
                nn.ConvTranspose1d(feat*2, feat, kernel_size=2, stride=2)
            )
            # Then apply ConvBlock: input has double channels due to skip connection
            self.ups.append(ConvBlock(feat*2, feat))
    
        # ----- Final Output Convolution -----
        # 1x1 convolution to map to desired output channels (e.g., P, S, noise)
        self.final_conv = nn.Conv1d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = F.max_pool1d(x, kernel_size=2)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_conn = skip_connections[idx//2]
            if x.shape[-1] != skip_conn.shape[-1]:
                x = F.pad(x, (0, skip_conn.shape[-1] - x.shape[-1]))
            x = torch.cat((skip_conn, x), dim=1)
            x = self.ups[idx+1](x)
        x = self.final_conv(x)
        return F.softmax(x, dim=1)

# 2. Load weights into the model

This cell loads the saved weights into the model.

In [64]:
model = UNet1D()
model.load_state_dict(torch.load("../Loic/UNet/model_weights_eq_only_v2.pt",weights_only=True, map_location=torch.device('cpu')))
model.eval() 

UNet1D(
  (downs): ModuleList(
    (0): ConvBlock(
      (conv): Sequential(
        (0): Conv1d(3, 16, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): ReLU()
        (2): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=(1,))
        (3): ReLU()
      )
    )
    (1): ConvBlock(
      (conv): Sequential(
        (0): Conv1d(16, 32, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): ReLU()
        (2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
        (3): ReLU()
      )
    )
    (2): ConvBlock(
      (conv): Sequential(
        (0): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): ReLU()
        (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
        (3): ReLU()
      )
    )
    (3): ConvBlock(
      (conv): Sequential(
        (0): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): ReLU()
        (2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
        (3): ReLU

# 3. Load and preprocess MiniSEED waveform data

The most important thing here is to make sure that the input data is similar to the training set.  This includes insuring that the sample rates are the same and the same normalizations applied to the preprocessing steps are applied to the actual data.  This includes sample rates, normalizations, and component order.



In [65]:
# Load and preprocess MiniSEED waveform data
import numpy as np
from obspy import read, Stream, Trace

# the PNW dataset is in ENZ order, if only Z is present, we need to create dummy E and N traces
from obspy import Stream, Trace
import numpy as np

def ensure_ENZ_order(st):
    """
    Ensure a stream has channels in E, N, Z order.
    Handles both E/N/Z and 1/2/Z naming conventions.
    If only Z is present, creates dummy E and N traces filled with zeros.
    """
    # Normalize component keys (map 1 -> E, 2 -> N, Z stays Z)
    comp_map = {}
    for tr in st:
        suffix = tr.stats.channel[-1].upper()
        if suffix == "1":
            comp_map["E"] = tr
        elif suffix == "2":
            comp_map["N"] = tr
        else:
            comp_map[suffix] = tr

    print("Detected components:", comp_map.keys())

    # If only Z is present, create dummy E and N
    if set(comp_map.keys()) == {"Z"}:
        z_trace = comp_map["Z"]
        npts = z_trace.stats.npts
        sampling_rate = z_trace.stats.sampling_rate
        starttime = z_trace.stats.starttime

        # Create dummy E and N traces
        e_trace = Trace(data=np.zeros(npts, dtype=np.float32), header=z_trace.stats.copy())
        n_trace = Trace(data=np.zeros(npts, dtype=np.float32), header=z_trace.stats.copy())

        e_trace.stats.channel = z_trace.stats.channel[:-1] + "E"
        n_trace.stats.channel = z_trace.stats.channel[:-1] + "N"

        return Stream(traces=[e_trace, n_trace, z_trace])

    # Reorder to E, N, Z (if available)
    ordered = []
    for comp in ["E", "N", "Z"]:
        if comp in comp_map:
            ordered.append(comp_map[comp])

    return Stream(traces=ordered)



# Load the data
st = read('../data/B047.PB.EH*')
# st = read('../../../shared/shortcourses/crescent_ml_2025/miniseed/B047*')
# Merge streams
st.merge(fill_value='interpolate')
# Ensure ENZ order
st = ensure_ENZ_order(st)
# Demean
st.detrend('demean')
# Resample to 100 Hz
for tr in st:
    if tr.stats.sampling_rate != 100:
        print(f"Resampling {tr.id} from {tr.stats.sampling_rate} Hz to {target_rate} Hz")
        tr.resample(target_rate)

# Your input data
data = np.stack([tr.data for tr in st], axis=0)  # shape: (channels, time)
chunk_size = 3001
num_channels, total_length = data.shape

# Calculate how much padding is needed
remainder = total_length % chunk_size
if remainder > 0:
    pad_width = chunk_size - remainder
    # Pad with zeros at the end along the time axis
    data = np.pad(data, ((0, 0), (0, pad_width)), mode='constant')

# Reshape to (num_chunks, 3, 3001)
num_chunks = data.shape[1] // chunk_size
chunks = data.reshape(num_channels, num_chunks, chunk_size).transpose(1, 0, 2)  # shape: (num_chunks, 3, 3001)

# Demean each chunk along the time axis
means = np.mean(chunks, axis=2, keepdims=True)  # shape: (num_chunks, 3, 1)
demeaned_chunks = chunks - means

# Normalize by the peak absolute value 
max_vals = np.max(np.abs(demeaned_chunks), axis=2, keepdims=True)  # shape: (num_chunks, 3, 1)
normalized_chunks = np.divide(
    demeaned_chunks,
    max_vals,
    out=np.zeros_like(demeaned_chunks),
    where=max_vals != 0
).astype(np.float32)

# Convert to torch tensor
input_tensor = torch.tensor(normalized_chunks, dtype=torch.float32, device=torch.device('cpu'))  # shape: (batch_size, 3, 3001)

Detected components: dict_keys(['E', 'N', 'Z'])


In [66]:
# Run inference
with torch.no_grad():
    pred = model(input_tensor).squeeze(0).cpu().numpy()  # shape: (2, T)

In [67]:
# Step 1: Reshape to (2, total_length_with_padding)
input_unwrapped = input_tensor.permute(1, 0, 2).reshape(3, -1)[:, :8640000]   # Shape: (3, num_chunks * 3001)
pred_unwrapped = pred.transpose(1, 0, 2).reshape(3, -1)[:, :8640000]   # shape: (2, num_chunks * 3001)

# 4. Write pick file

In [68]:
# Output file
output_file = 'detections_%s.%s.%s.%s.csv'%(tr.stats.network,tr.stats.station,tr.stats.starttime.year,tr.stats.starttime.julday)

import csv
from obspy import UTCDateTime
from scipy.signal import find_peaks

# Example settings (replace with actual values)
threshold = 0.1
starttime = tr.stats.starttime  # trace start time

# Write header
with open(output_file, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['phase', 'time', 'confidence'])

    for phase_idx, phase_label in enumerate(['P', 'S']):
        probs = pred_unwrapped[phase_idx, :]

        # Find peaks in the probability curve
        peaks, properties = find_peaks(probs, height=threshold, distance=tr.stats.sampling_rate)  # e.g. at least 1s apart

        for peak_idx in peaks:
            peak_time = starttime + peak_idx / tr.stats.sampling_rate
            confidence = probs[peak_idx]
            writer.writerow([phase_label, peak_time,f'{confidence:.3f}'])

! head detections*

phase,time,confidence
P,2022-12-20T00:01:40.038400Z,0.327
P,2022-12-20T00:01:41.648400Z,0.789
P,2022-12-20T00:01:45.568400Z,0.642
P,2022-12-20T00:01:50.978400Z,0.306
P,2022-12-20T00:01:55.968400Z,0.106
P,2022-12-20T00:02:06.428400Z,0.342
P,2022-12-20T00:02:20.758400Z,0.108
P,2022-12-20T00:02:23.068400Z,0.292
P,2022-12-20T00:02:30.058400Z,0.279


# 5. Plot the predictions

In [71]:
picks

Unnamed: 0,index,phase,time,confidence
0,0,P,2022-12-20T00:01:40.038400Z,0.327
1,1,P,2022-12-20T00:01:41.648400Z,0.789
2,2,P,2022-12-20T00:01:45.568400Z,0.642
3,3,P,2022-12-20T00:01:50.978400Z,0.306
4,4,P,2022-12-20T00:01:55.968400Z,0.106
...,...,...,...,...
4005,4005,S,2022-12-20T23:59:36.108400Z,0.445
4006,4006,S,2022-12-20T23:59:45.758400Z,0.241
4007,4007,S,2022-12-20T23:59:49.668400Z,0.946
4008,4008,S,2022-12-20T23:59:53.888400Z,0.205


In [89]:
import pandas as pd

# Apply bandpass filter to each trace in-place
st.filter('bandpass', freqmin=1.0, freqmax=20.0, corners=4, zerophase=True)
plot_data = np.stack([tr.data for tr in st], axis=0)  # shape: (channels, time)

# Assume 100 Hz sampling rate; adjust as needed
sampling_rate = st[0].stats.sampling_rate
start_time = st[0].stats.starttime

# load pick file
picks=pd.read_csv(output_file)

# Get max amplitude in Z-channel
maxind = np.argmax(plot_data[2, :])
winlen = 6000  # window length in samples

for ii in np.arange(maxind // winlen - 1, maxind // winlen + 9):
    istart = ii * winlen
    ifinish = istart + winlen
    if istart < 0 or ifinish > plot_data.shape[1]:
        continue

    # Compute timestamps for the window
    window_start_time = start_time + istart / sampling_rate
    window_end_time = start_time + ifinish / sampling_rate

    # Filter picks by time window
    mask = (picks['time'] >= window_start_time.datetime) & (picks['time'] < window_end_time.datetime)
    ppicks_filtered = picks[mask & (picks['phase'] == 'P')]
    spicks_filtered = picks[mask & (picks['phase'] == 'S')]

    # Convert pick times to sample indices
    ppick_indices = ((ppicks_filtered['time'] - start_time.datetime).dt.total_seconds() * sampling_rate).astype(int)
    spick_indices = ((spicks_filtered['time'] - start_time.datetime).dt.total_seconds() * sampling_rate).astype(int)

    # Plotting
    fig, axs = plt.subplots(2, 1, figsize=(15, 6), sharex=True)
    
    axs[0].set_title('Input Seismic Data')
    axs[0].plot(plot_data[0, istart:ifinish], label='E-component', color='tab:red')
    axs[0].plot(plot_data[1, istart:ifinish], label='N-component', color='tab:grey')
    axs[0].plot(plot_data[2, istart:ifinish], label='Z-component', color='tab:blue')
    axs[0].legend()

    axs[1].set_title('Prediction')
    # If you have predictions:
    # axs[1].plot(pred_unwrapped[0, istart:ifinish], label='P-phase', color='tab:red')
    # axs[1].plot(pred_unwrapped[1, istart:ifinish], label='S-phase', color='tab:blue')
    # axs[1].plot(pred_unwrapped[2, istart:ifinish], label='Noise', color='tab:grey')

    axs[1].plot(ppick_indices - istart, ppicks_filtered['confidence'], 'o', color='tab:red', label='P-pick')
    axs[1].plot(spick_indices - istart, spicks_filtered['confidence'], 'o', color='tab:blue', label='S-pick')
    axs[1].legend()

    plt.tight_layout()
    plt.show()


TypeError: '>=' not supported between instances of 'str' and 'datetime.datetime'

Bad pipe message: %s [b'\x13\xcb\x057\x9b\xa8\xde\x00\x95\x15w<\xa7&\xfb\xf4\x81\x98 \xb2\xb0\xa8\x00\xf4\xa5\x0c\x85F\xfaI\x15\xdc]+\xe3\xd3Ud\xf6\xf5f\xcd$']
Bad pipe message: %s [b'gc\x95\xfa\xcbQ\x98\xfa\x10\x19-\xaeF\xb8o\x19"B !W1\x8d|\x1b\xdd\x13\xf9\x1c\xd9\xdd\xd4\xb9\x10\xcc\x83\x99 \x9c\xeb\x8d)\xa6!\xf6\xcb\xb5n\x00_b\x00 \x1a\x1a\x13\x01\x13\x02\x13\x03\xc0+\xc0/\xc0,\xc00\xcc\xa9\xcc\xa8\xc0\x13\xc0\x14\x00\x9c\x00\x9d\x00/\x005\x01\x00\x06\x93\n\n\x00\x00\x00\x05\x00\x05\x01\x00\x00\x00\x00\xfe\r\x01\x1a\x00\x00\x01\x00\x01\xef\x00 F\xce\x0c\x1f\x8eD\xec\xb5\x9aZ\xfa\xcfj\xd0L\xefI\xc2\xb3\x8e\xc4\x15\xc9\x9f\xa0\xe5\xce\x8c \xf9\xeat\x00\xf0\xfa\xaa\x00p\xb3A*\x03\xffyHGt\x94\xc2\x1e\xbc5\x8df\xd3kIIA\xbekA\xf2o\x021L,)\x82\x95hX $\xac\xe0\xa6\x81\rv\x81\x18\xd9\\\r\x85\x95qW\xad\xd4\x86\x05!+\x84\xb6m\xc5\xf0<JI\x03\xf6O\x9dj\xbb\x8c\xbe\xdc\xf9\x02H\xf5\xea\xf8\xb4\x0e']
Bad pipe message: %s [b'\xd1\x9f5\x1aL\x1e\xa2f\xaa\x87E\x11\x90\xe8E\xa3P\xa4\xb9\x97\xdc\')\x89\

In [82]:
start

76062.09

In [88]:
picks.loc[0,'time']


Timestamp('2022-12-20 00:01:40.038400+0000', tz='UTC')

Bad pipe message: %s [b'\x03\xac\xa2\xf2"\xb2\xbf\xd6P\xff-\x19\xe3\xe5X\x92U\xa1 \xa8h\x1d!\x139`n\x1d\xed^ \xc6\xa2\xe2\xb2Rd\x8e\x12_8\x8a\xa1\xf2\xd8z\x86Pb\xb8\x82\x00 ::\x13\x01\x13\x02\x13\x03\xc0+\xc0/\xc0,\xc00\xcc\xa9\xcc\xa8\xc0\x13\xc0\x14\x00\x9c\x00\x9d\x00/\x005\x01\x00\x06S\xba\xba\x00\x00\x00+\x00\x07\x06\xda\xda\x03\x04\x03\x03\x00#\x00\x00\x00\n\x00\x0c\x00\n**\x11\xec\x00\x1d\x00\x17\x00\x18\x00\x1b\x00\x03\x02\x00\x02\x00\r\x00\x12\x00\x10\x04\x03\x08\x04\x04\x01\x05\x03\x08\x05\x05\x01\x08\x06\x06\x01\xfe\r\x00\xda\x00\x00\x01\x00\x01V\x00 % \xbd\x19\xa2\x94l\xc0\x0eUQ\x88\x89\xe3o$\xcc\x19\xbc\x05\x00\xee*\xdf\xba1\x02\xca\xda\x86\x89B\x00\xb0\x1c', b'%\xe4\xc9T\xa3\xa1\xdb\xa6\x15\xfe\xfa\x9d\xd8p\xd5\xea#\x11i\x99?7\xc4\x94\xc6\xa9\xca\xc6\x08IE/T\xac\x15\xd4\xa6\xe5$\xdb\xda\xf9.X[\xdc\xb6\x0b\x08\xa1`n\xf8\x89\x80\x03dQ\xa6\x1f\xc2R\x0e\xc8C\xfe,\xc3\xff\x8d\xc5\xe3\xb2\xf2K\xe1[\x80=\xc7\x00\x85Y\x97.B1\xc5\xa6\xce{\xbe\x19\xc6\xdc\xf1)L\xa1\xe9\x8e\xeb\xde\