We are interested in the [STEAD](https://github.com/smousavi05/STEAD) dataset.

STEAD (STanford EArthquake Dataset) is a large-scale dataset of seismic waveforms designed for machine learning applications. It contains more than a million labeled seismic waveforms, including earthquake and noise records, with metadata such as event magnitude, depth, station location, and more.

You can download each chunk from the [GitHub](https://github.com/smousavi05/STEAD?tab=readme-ov-file) page of the STEAD dataset.

Each chunk consists of an HDF5 file (containing waveform data) and a CSV file (containing metadata).

* Note 1: some unzipping programs for Windows and Linux have file-size limits. Try 7-Zip if you have problems unzipping the files.

* Note 2: all the metadata are also available in the HDF5 file (as attributes attached to each waveform), but the CSV file makes it easy to select a specific part of the dataset and then read only the associated waveforms from the HDF5 file, which is more efficient.

* Note 3: for some noise records the waveforms are identical on all 3 components. These come from single-channel stations, where the vertical channel was duplicated into the horizontal ones. They make up less than 4% of the noise data; for the rest, the noise differs on each channel.
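Following Note 3, traces whose three components are identical can be flagged before training. A minimal NumPy sketch, assuming a single trace is a `(6000, 3)` array with columns E, N, Z as stored in the HDF5 file; the function name `has_duplicated_components` is hypothetical:

```python
import numpy as np


def has_duplicated_components(data: np.ndarray) -> bool:
    """Return True if all three columns (E, N, Z) of a (n_samples, 3) trace are identical.

    Per Note 3, this happens for single-channel stations whose vertical channel
    was copied into the two horizontal ones.
    """
    if data.ndim != 2 or data.shape[1] != 3:
        raise ValueError(f"expected shape (n_samples, 3), got {data.shape}")
    vertical = data[:, 2]
    return bool(np.array_equal(data[:, 0], vertical) and np.array_equal(data[:, 1], vertical))


# Usage on synthetic data
rng = np.random.default_rng(0)
z = rng.standard_normal(6000)
duplicated = np.column_stack([z, z, z])    # single-channel station pattern
distinct = rng.standard_normal((6000, 3))  # ordinary 3-component noise
print(has_duplicated_components(duplicated))  # True
print(has_duplicated_components(distinct))    # False
```

Exact equality is enough here because the duplicated channels are straight copies of the vertical one.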

Before proceeding, it is useful to understand the structure of an HDF5 file. Its basic components are:

1. 📂 **Groups (Like Folders in a File System)**: Groups are containers that organize datasets and other groups (like directories in a filesystem). The root group ` / ` is the top-level container in the HDF5 file.
Example:
```
/
├── metadata
├── data
│   ├── event1
│   ├── event2
│   ├── event3
```

2. 📄 **Datasets (Like Files in a Folder)**: Datasets contain the actual numerical or text data (like files in a folder). Datasets can be multi-dimensional (like NumPy arrays).
Example:
```
data/event1  →  [1000x3] array (1000 samples, 3 components)
```

3. 📝 **Attributes (Metadata Associated with Groups or Datasets)**: Attributes store metadata about groups or datasets (like file properties).
Example:
```
data/event1.attrs
├── p_arrival_sample: 230
├── s_arrival_sample: 450
├── coda_end_sample: 800
├── station_name: ABC
├── magnitude: 3.5
```
Attributes are not datasets: they are small pieces of metadata (e.g., timestamps, locations, experiment details).
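The three building blocks can be tried out with `h5py`, the library this notebook uses to read STEAD. A minimal sketch that writes a throwaway file with the illustrative layout above and reads it back; the file name and all values are made up:

```python
import os
import tempfile

import h5py
import numpy as np

# Build a small file with the layout sketched above (names and values are illustrative)
path = os.path.join(tempfile.mkdtemp(), "example.hdf5")

with h5py.File(path, "w") as f:
    data_group = f.create_group("data")  # group: like a folder
    f.create_group("metadata")
    for name in ["event1", "event2", "event3"]:
        # dataset: a (1000, 3) array, like a file inside the folder
        dset = data_group.create_dataset(name, data=np.zeros((1000, 3), dtype=np.float32))
        # attributes: small metadata attached to the dataset
        dset.attrs["p_arrival_sample"] = 230
        dset.attrs["s_arrival_sample"] = 450
        dset.attrs["station_name"] = "ABC"

# Read it back: list the groups, then one dataset and its attributes
with h5py.File(path, "r") as f:
    top_level = sorted(f.keys())       # ['data', 'metadata']
    events = sorted(f["data"].keys())  # ['event1', 'event2', 'event3']
    dset = f["data/event1"]
    shape = dset.shape                 # (1000, 3)
    p_arrival = int(dset.attrs["p_arrival_sample"])
    station = dset.attrs["station_name"]
    print(top_level, events, shape, p_arrival, station)
```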

The structure of a chunk of the STEAD dataset looks like this (trace names and attribute values are illustrative; every STEAD trace is a 6000x3 array, i.e. 60 s of 3-component data sampled at 100 Hz):
```
/
├── data (group)
│   ├── event1 (dataset, 6000x3 array)
│   │   ├── p_arrival_sample: 230
│   │   ├── s_arrival_sample: 450
│   │   ├── coda_end_sample: 800
│   │   ├── source_magnitude: 3.5
│   │   ├── source_distance_km: 15.2
│   │   ├── ...
│   │
│   ├── event2 (dataset, 6000x3 array)
│   │   ├── p_arrival_sample: 250
│   │   ├── s_arrival_sample: 470
│   │   ├── ...
│   │
│   ├── ...

```
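Sample indices such as `p_arrival_sample` can be turned into times by dividing by the sampling rate. A small standard-library sketch, assuming a sampling rate of 100 Hz (consistent with the `delta = 0.01` used when building ObsPy traces in this notebook) and a `trace_start_time` string that `datetime.fromisoformat` can parse; the helper name `arrival_times` and the start time are hypothetical:

```python
from datetime import datetime, timedelta

SAMPLING_RATE_HZ = 100.0  # assumption: 100 Hz, i.e. delta = 0.01 s


def arrival_times(attrs: dict, sampling_rate: float = SAMPLING_RATE_HZ) -> dict:
    """Convert P/S arrival sample indices into (seconds after start, absolute time)."""
    start = datetime.fromisoformat(attrs["trace_start_time"])
    out = {}
    for phase in ("p", "s"):
        seconds = attrs[f"{phase}_arrival_sample"] / sampling_rate
        out[phase] = (seconds, start + timedelta(seconds=seconds))
    return out


# Illustrative attribute values (same sample numbers as the tree above, made-up start time)
attrs = {
    "trace_start_time": "2019-07-06 03:19:53.040000",
    "p_arrival_sample": 230,
    "s_arrival_sample": 450,
}
times = arrival_times(attrs)
print(times["p"][0], times["s"][0])   # 2.3 4.5
print(times["s"][0] - times["p"][0])  # S-P time in seconds
```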

For further details on the dataset, see the official GitHub page and the article [STanford EArthquake Dataset (STEAD): A Global Data Set of Seismic Signals for AI](https://www.researchgate.net/publication/336598670_STanford_EArthquake_Dataset_STEAD_A_Global_Data_Set_of_Seismic_Signals_for_AI).

# HDF5 to Torch tensor

In [None]:
! pip install obspy

In [None]:
import pandas as pd
import h5py
import numpy as np
import obspy
import torch
from obspy import UTCDateTime
from obspy.clients.fdsn import Client

file_name = "path/to/STEAD_data/chunk2.hdf5"
csv_file = "path/to/STEAD_data/chunk2/chunk2.csv"

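# reading the csv file into a dataframe (low_memory=False avoids pandas' mixed-dtype warning):
df = pd.read_csv(csv_file, low_memory=False)
print(f'total events in csv file: {len(df)}')
# excluding traces recorded by these networks
df = df[~df.network_code.isin(['IV', 'HA', 'KO', 'HP', 'FR', 'S', 'TU'])]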

# making a list of trace names for the selected data
ev_list = df['trace_name'].to_list()
print(len(ev_list))



total events in csv file: 200000
198724
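Note 2 suggests using the CSV to pick a subset before touching the HDF5 file. A minimal pandas sketch of a more specific selection on a toy DataFrame; the column names `trace_category` and `source_magnitude` and the value `earthquake_local` are assumptions about the STEAD CSV (check `df.columns` on your chunk), and the rows are made up:

```python
import pandas as pd

# Toy stand-in for the STEAD metadata CSV (values are made up)
df_demo = pd.DataFrame({
    "trace_name": ["A.XX_1_EV", "B.IV_2_EV", "C.YY_3_NO", "D.ZZ_4_EV"],
    "network_code": ["XX", "IV", "YY", "ZZ"],
    "trace_category": ["earthquake_local", "earthquake_local", "noise", "earthquake_local"],
    "source_magnitude": [3.4, 4.1, float("nan"), 1.2],
})

excluded_networks = ["IV", "HA", "KO", "HP", "FR", "S", "TU"]

# Combine boolean masks: drop excluded networks, keep earthquakes of magnitude >= 3
mask = (
    ~df_demo["network_code"].isin(excluded_networks)
    & (df_demo["trace_category"] == "earthquake_local")
    & (df_demo["source_magnitude"] >= 3.0)
)
selected = df_demo.loc[mask, "trace_name"].to_list()
print(selected)  # ['A.XX_1_EV']
```

The resulting list of trace names plays the same role as `ev_list`: only those keys are then read from `data/` in the HDF5 file.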


In [None]:
# Custom stream builder; also handles stations whose horizontal channels are named 1/2 instead of E/N (e.g. the PB network)

def make_stream(dataset):
    '''
    input: hdf5 dataset (6000x3 array; columns are E, N, Z)
    output: obspy stream with one trace per component
    '''
    data = np.array(dataset)
    attrs = dataset.attrs
    network = attrs['network_code']

    # PB stations name their horizontal components 1/2 instead of E/N
    suffixes = ['1', '2', 'Z'] if network == 'PB' else ['E', 'N', 'Z']

    traces = []
    for i, suffix in enumerate(suffixes):
        tr = obspy.Trace(data=data[:, i])
        tr.stats.starttime = UTCDateTime(attrs['trace_start_time'])
        tr.stats.delta = 0.01  # 100 Hz sampling
        tr.stats.channel = attrs['receiver_type'] + suffix
        tr.stats.station = attrs['receiver_code']
        tr.stats.network = network
        # hard-coded location codes for these networks (all others keep ObsPy's default '')
        if network == 'GM':
            tr.stats.location = '01'
        elif network in ['II', 'US', 'NM', 'ET']:
            tr.stats.location = '00'
        traces.append(tr)

    stream = obspy.Stream(traces)

    return stream
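The naming logic inside `make_stream` can be checked in isolation. A standard-library sketch that reproduces only the channel/location rules used above (1/2 horizontals for PB, location `01` for GM, `00` for II, US, NM and ET, empty otherwise); the function name `channel_codes` is hypothetical and it does not touch ObsPy:

```python
def channel_codes(network: str, receiver_type: str) -> list[tuple[str, str]]:
    """Return (channel, location) pairs for the E, N, Z columns, mirroring make_stream."""
    suffixes = ["1", "2", "Z"] if network == "PB" else ["E", "N", "Z"]
    if network == "GM":
        location = "01"
    elif network in ["II", "US", "NM", "ET"]:
        location = "00"
    else:
        location = ""  # ObsPy's default location code
    return [(receiver_type + s, location) for s in suffixes]


print(channel_codes("PB", "EH"))  # [('EH1', ''), ('EH2', ''), ('EHZ', '')]
print(channel_codes("US", "BH"))  # [('BHE', '00'), ('BHN', '00'), ('BHZ', '00')]
```

These codes must match the channels returned by `client.get_stations(...)`, otherwise `remove_response` cannot find a response for the trace.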

In [None]:
client = Client("IRIS")

chunk = "chunk2"

waveform_list = []

In [None]:
# retrieving selected waveforms from the hdf5 file (sequential version):
dtfl = h5py.File(file_name, 'r')
for c, evi in enumerate(ev_list):

    if c % 100 == 0:
        print(f'{c}')

    dataset = dtfl.get('data/' + str(evi))

    # converting hdf5 dataset (6000x3; columns: E, N, Z) into an obspy stream
    st = make_stream(dataset)

    try:
        # fetching the station response from IRIS
        inventory = client.get_stations(network=dataset.attrs['network_code'],
                                        station=dataset.attrs['receiver_code'],
                                        starttime=UTCDateTime(dataset.attrs['trace_start_time']),
                                        endtime=UTCDateTime(dataset.attrs['trace_start_time']) + 60,
                                        loc="*",
                                        channel="*",
                                        level="response")

        # converting raw counts into acceleration (the response is removed exactly once)
        st = st.remove_response(inventory=inventory, output="ACC", plot=False)
    except Exception as e:
        print(f"Response removal failed for {evi}: {e}")
        # NaN placeholder keeps waveform_list aligned with ev_list
        traces = np.full((3, 6000), np.nan, dtype=np.float64)  # Shape: (3, datapoints)
        waveform_list.append(traces)
        continue

    # Convert ObsPy stream to NumPy array (shape: 3 x datapoints)
    traces = np.array([tr.data for tr in st])

    # Store in list
    waveform_list.append(traces)
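Many traces in a chunk come from the same station, so the loop above asks IRIS for the same response metadata over and over. One possible optimization, sketched with the standard library only: cache the result of a fetch function keyed by network, station and calendar day. The class name `ResponseCache`, the day-level key and the `fake_fetch` stand-in are assumptions for illustration; in the notebook the fetch function would wrap `client.get_stations(..., level="response")` with a time window covering that whole day, since a station's response can change when its instruments are replaced.

```python
from datetime import datetime
from typing import Any, Callable, Dict, Tuple


class ResponseCache:
    """Memoize station-response lookups keyed by (network, station, day).

    `fetch(network, station, day)` is any callable returning the metadata for that
    station on that day; it is called at most once per key (failures are not cached).
    """

    def __init__(self, fetch: Callable[[str, str, str], Any]):
        self._fetch = fetch
        self._cache: Dict[Tuple[str, str, str], Any] = {}
        self.misses = 0

    def get(self, network: str, station: str, start: datetime) -> Any:
        key = (network, station, start.date().isoformat())
        if key not in self._cache:
            self.misses += 1
            self._cache[key] = self._fetch(*key)
        return self._cache[key]


# Usage with a fake fetch function that only records how often it is called
calls = []


def fake_fetch(network, station, day):
    calls.append((network, station, day))
    return f"inventory for {network}.{station} on {day}"


cache = ResponseCache(fake_fetch)
for t in ["2019-07-06 03:19:53", "2019-07-06 11:02:10", "2019-07-07 00:00:01"]:
    cache.get("XX", "STA1", datetime.fromisoformat(t))
print(len(calls))  # 2: one request per station-day
```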

In the following cells we check whether the conversion can be parallelized. If that is not possible on your machine, we suggest either running the notebook on Google Colab or converting the parallel loop into a single-process one (which should be straightforward).

In [None]:
import multiprocessing
print(multiprocessing.cpu_count())

In [None]:
import os
print(f"Logical CPUs: {os.cpu_count()}")
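`os.cpu_count()` reports the CPUs in the machine, which can be more than this process is allowed to use (for example when a job scheduler or cpuset pins it to a subset of cores). A small sketch that prefers `os.sched_getaffinity`, which exists only on some platforms such as Linux, and falls back to `os.cpu_count()`; the helper name `usable_cpu_count` is hypothetical:

```python
import os


def usable_cpu_count() -> int:
    """Number of CPUs this process may run on (falls back to the machine total)."""
    if hasattr(os, "sched_getaffinity"):  # Linux and some other Unix systems
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1  # cpu_count() can return None


print(f"Usable CPUs: {usable_cpu_count()} of {os.cpu_count()}")
```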

In [None]:
import threading
import time
from concurrent.futures import ThreadPoolExecutor

def worker(n):
    """Simple function that simulates work by sleeping."""
    print(f"Thread {threading.current_thread().name} is running")
    time.sleep(2)  # Simulates some work
    return n

# Set number of threads (e.g., 4)
num_threads = 4

start_time = time.time()

with ThreadPoolExecutor(max_workers=num_threads) as executor:
    results = list(executor.map(worker, range(num_threads)))

end_time = time.time()

print(f"Execution time: {end_time - start_time:.2f} seconds")

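# If the total stays close to 2 s, the 4 threads ran concurrently rather than one after another.
# Note: time.sleep releases the GIL, so this checks thread concurrency, not how many CPU cores are available.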


## Parallel Loop

In [None]:
import h5py
import numpy as np
import os
from obspy import UTCDateTime
from obspy.clients.fdsn import Client
from concurrent.futures import ProcessPoolExecutor
import time

In [None]:
import numpy as np
import h5py
import time
import os
from multiprocessing import Array
from obspy import UTCDateTime
from obspy.clients.fdsn import Client

# Shared memory parameters
num_channels = 3
num_datapoints = 6000
num_events = len(ev_list)  # Make sure ev_list is defined

# Create a shared NumPy array using multiprocessing.Array
shared_array_base = Array("d", num_events * num_channels * num_datapoints)  # Shared flat array
shared_array = np.frombuffer(shared_array_base.get_obj(), dtype=np.float64)
shared_array = shared_array.reshape((num_events, num_channels, num_datapoints))
shared_array[:] = np.nan  # Initialize with NaN

client = Client("IRIS")  # Persistent FDSN client

def process_waveform(index_evi):
    """Function to process a single waveform event and write directly to shared memory."""
    index, evi = index_evi

    with h5py.File(file_name, "r") as dtfl:
        dataset = dtfl.get(f"data/{evi}")
        if dataset is None:
            return  # Skip if missing

        try:
            # Convert dataset into ObsPy stream (make_stream is defined above)
            st = make_stream(dataset)

            # Fetch the station response from IRIS
            inv = client.get_stations(
                network=dataset.attrs["network_code"],
                station=dataset.attrs["receiver_code"],
                starttime=UTCDateTime(dataset.attrs["trace_start_time"]),
                endtime=UTCDateTime(dataset.attrs["trace_start_time"]) + 60,
                loc="*",
                channel="*",
                level="response"
            )

            # Remove response
            st = st.remove_response(inventory=inv, output="ACC", plot=False)

            # Write results directly to shared array
            shared_array[index] = np.array([tr.data for tr in st])

        except Exception as e:
            print(f"Failed processing event {evi}: {e}")
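Before allocating the shared array it is worth checking that it fits in RAM: it holds `num_events * 3 * 6000` float64 values. A quick standard-library calculation for the 198,724 traces selected above (the helper name `array_bytes` is hypothetical):

```python
def array_bytes(n_events: int, n_channels: int = 3, n_samples: int = 6000, itemsize: int = 8) -> int:
    """Bytes needed for a dense (n_events, n_channels, n_samples) array."""
    return n_events * n_channels * n_samples * itemsize


n_events = 198724  # number of traces kept after filtering chunk2 (printed above)
float64_bytes = array_bytes(n_events)              # Array("d", ...) stores 8-byte doubles
float32_bytes = array_bytes(n_events, itemsize=4)  # the same data stored as float32

print(f"float64: {float64_bytes:,} bytes = {float64_bytes / 2**30:.2f} GiB")
print(f"float32: {float32_bytes:,} bytes = {float32_bytes / 2**30:.2f} GiB")
```

For chunk2 this gives 28,616,256,000 bytes (about 26.65 GiB) in float64 and half of that in float32; each additional full copy, such as `shared_array.copy()`, needs the same amount again.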


In [None]:
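import multiprocessing

# The workers mostly wait on IRIS web requests, so more workers than CPU cores can help;
# keep the number moderate so as not to overload the FDSN service.
# num_workers = min(os.cpu_count(), len(ev_list))
num_workers = 32

start_time = time.time()

# This cell relies on the workers inheriting shared_array (and the notebook-defined functions)
# through fork, so request the 'fork' start method explicitly: it is not available on Windows
# and not the default on every platform/Python version. list() consumes the results so that
# exceptions raised in a worker are re-raised here instead of being silently dropped.
ctx = multiprocessing.get_context("fork")
with ProcessPoolExecutor(max_workers=num_workers, mp_context=ctx) as executor:
    list(executor.map(process_waveform, enumerate(ev_list), chunksize=16))

end_time = time.time()

print(f"Execution time: {end_time - start_time:.2f} seconds")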

# Copy the shared buffer into a regular NumPy array (optional; needs as much memory again as shared_array)
waveform_array = shared_array.copy()
print(waveform_array.shape)
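Events whose response removal failed are left as all-NaN rows, because the shared array is initialized with NaN. A NumPy sketch for finding and dropping them while keeping the trace names aligned; it runs on a small synthetic array, and `drop_failed` is a hypothetical helper (on the full array the boolean indexing makes one more copy):

```python
import numpy as np


def drop_failed(waveforms: np.ndarray, names: list) -> tuple:
    """Remove events whose (3, n_samples) block is entirely NaN; keep names aligned."""
    failed = np.isnan(waveforms).all(axis=(1, 2))  # one boolean per event
    kept_names = [n for n, bad in zip(names, failed) if not bad]
    return waveforms[~failed], kept_names, int(failed.sum())


# Synthetic example: 4 events, event 2 never got written
demo = np.random.default_rng(1).standard_normal((4, 3, 6000))
demo[2] = np.nan
clean, clean_names, n_failed = drop_failed(demo, ["ev0", "ev1", "ev2", "ev3"])
print(clean.shape, clean_names, n_failed)  # (3, 3, 6000) ['ev0', 'ev1', 'ev3'] 1
```

In the notebook the call would be `drop_failed(waveform_array, ev_list)`.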


In [None]:
# sanity check: plot one component (direction) of the first 10 converted traces
import matplotlib.pyplot as plt

direction = 0  # 0: E (or 1), 1: N (or 2), 2: Z

for idx in range(10):
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(waveform_array[idx, direction, :], "k-")
    plt.xlabel('timestep (0.01 s)')
    plt.ylabel('acceleration (m/s$^2$)')
    plt.title(f'trace {idx}')
    plt.show()



In [None]:
# Wrap the NumPy array as a PyTorch tensor of shape (n_samples, 3, datapoints).
# torch.from_numpy shares memory with waveform_array, avoiding the extra copies of np.stack + torch.tensor
waveform_tensor = torch.from_numpy(waveform_array)  # float64, (n_samples, 3, datapoints)

# Save the tensor to a .pt file (the STEAD_data/{chunk}/ folder must exist)
torch.save(waveform_tensor, f"STEAD_data/{chunk}/{chunk}_acceleration.pt")

# Close the HDF5 file opened by the sequential loop
dtfl.close()

print(f"Tensor saved as {chunk}_acceleration.pt with shape:", waveform_tensor.shape)
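For arrays of this size, a plain `.npy` file opened with `mmap_mode` is an alternative to `.pt` that lets you read a few traces without loading everything into RAM. A NumPy-only sketch on a small synthetic array (the file name is made up); copy the slice you need, as below, before wrapping it with `torch.from_numpy`:

```python
import os
import tempfile

import numpy as np

# Small synthetic stand-in for waveform_array: (n_samples, 3, datapoints)
waveforms = np.random.default_rng(2).standard_normal((5, 3, 6000))

path = os.path.join(tempfile.mkdtemp(), "chunk_acceleration.npy")
np.save(path, waveforms)

# Memory-map the file: data stays on disk and only the accessed slices are read
mapped = np.load(path, mmap_mode="r")
first_trace = np.array(mapped[0])  # copy one (3, 6000) trace into memory
print(mapped.shape, first_trace.shape)  # (5, 3, 6000) (3, 6000)
```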

In [None]:
accelerations = torch.load(f"STEAD_data/{chunk}/{chunk}_acceleration.pt")
print(accelerations.shape)

In [None]:
import matplotlib.pyplot as plt

idx = 1
direction = 0  # 0: E (or 1), 1: N (or 2), 2: Z

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(accelerations[idx, direction, :], "k-")
plt.xlabel('timestep (0.01 s)')
plt.ylabel('acceleration (m/s$^2$)')
plt.title(f'trace {idx}')
plt.show()
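The plots above show one component against the sample index. A sketch that draws all three components of one trace against time in seconds, assuming 100 Hz sampling (`delta = 0.01` s) and the E/N/Z order used by `make_stream`; it uses a synthetic trace so it runs on its own, and the function name `plot_three_components` is hypothetical:

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np


def plot_three_components(trace: np.ndarray, delta: float = 0.01, title: str = ""):
    """Plot a (3, n_samples) trace as three stacked panels (E, N, Z) against time."""
    t = np.arange(trace.shape[1]) * delta
    fig, axes = plt.subplots(3, 1, sharex=True, figsize=(8, 6))
    for ax, component, label in zip(axes, trace, ["E", "N", "Z"]):
        ax.plot(t, component, "k-", linewidth=0.5)
        ax.set_ylabel(f"{label} (m/s$^2$)")
    axes[-1].set_xlabel("time (s)")
    axes[0].set_title(title)
    return fig


# Synthetic trace standing in for accelerations[idx].numpy()
demo = np.random.default_rng(3).standard_normal((3, 6000)) * 1e-4
fig = plot_three_components(demo, title="trace 1")
```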