# Creating a dataset

This tutorial provides a short example of how to create a SeisBench dataset. Datasets can be created from any event catalog and waveform collection. For this example, we read an event catalog from a local file and load the associated waveforms from local GCF files; all paths are taken from a configuration file. We use built-in SeisBench functions for writing out the dataset in SeisBench format. In this notebook we aim for an easy example outlining the principles of dataset creation. There are a few further considerations, in particular for converting larger datasets.

**Note:** Some familiarity with ObsPy is helpful for this tutorial, but not required.

In [None]:
import seisbench.data as sbd
import seisbench.util as sbu
from pathlib import Path
from obspy import read_events
from obspy import read
from obspy import Stream
import pandas as pd
from config import load_config

In [None]:
import warnings
warnings.simplefilter('ignore', DeprecationWarning)

#### Loading configuration file

In [None]:
cfg = load_config('/home/ekarkooti/ikahbasi-PhD/Dataset/Kaki/Kaki-cfg.yml')
print(cfg)
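
The `load_config` helper comes from a local `config` module that is not shown in this notebook. Judging from how `cfg` is used in the later cells (`cfg.path.catalog`, `cfg.path.network_details`, `cfg.path.stream`, `cfg.path.out`), it appears to return an object with attribute access to nested settings. The following is only a sketch of that idea, not the actual helper: `to_namespace` and all file names are hypothetical, and a plain dictionary stands in for the parsed YAML file.

```python
from types import SimpleNamespace


def to_namespace(obj):
    """Recursively convert nested dicts into objects with attribute access."""
    if isinstance(obj, dict):
        return SimpleNamespace(
            **{key: to_namespace(value) for key, value in obj.items()}
        )
    return obj


# Hypothetical settings; the key names mirror the attributes used in this
# notebook, the values are placeholders.
example_settings = {
    "path": {
        "catalog": "catalog.xml",
        "network_details": "network_details.csv",
        "stream": "waveforms",
        "out": "seisbench_dataset",
    }
}

example_cfg = to_namespace(example_settings)
print(example_cfg.path.catalog)
```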

#### The event catalog

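As a first step, we need an event catalog. Here, we read the catalog from the file given in the configuration and keep only events that have at least one pick. We also load the network details table, which is used later to resolve station names. For demonstration purposes, we only use the first 50 events.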

In [None]:
network_details = pd.read_csv(cfg.path.network_details, dtype=str)
network_details.fillna(value='', inplace=True)

In [None]:
catalog = read_events(cfg.path.catalog)
catalog = [ev for ev in catalog if ev.picks != []]
catalog = catalog[:50]

In [None]:
print(len(catalog), catalog, sep='\n') # print(catalog.__str__(print_all=True))

lst = []
for ev in catalog:
    for pick in ev.picks:
        lst.append(pick.phase_hint)

for el in set(lst):
    print(el, lst.count(el))
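
The phase-hint tally above can also be written with `collections.Counter`, which counts all labels in a single pass. This is a minimal sketch on a hypothetical list of phase hints standing in for the `pick.phase_hint` values of the catalog.

```python
from collections import Counter

# Hypothetical phase hints; in the notebook these would come from
# `pick.phase_hint` for every pick of every event.
phase_hints = ["P", "S", "P", "Pg", "S", "P"]

phase_counts = Counter(phase_hints)
for phase, count in phase_counts.most_common():
    print(phase, count)
```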

#### Extracting the event parameters

From the catalog, we extract the event parameters and store them in a dictionary. Here, we only extract a few basic parameters on the source and, if available, its magnitude. In addition, we define the split of the dataset into training/development/test partitions based on the origin time.

In [None]:
def get_event_params(event):
    origin = event.preferred_origin()
    mag = event.preferred_magnitude()

    source_id = str(event.resource_id)

    event_params = {
        "source_id": source_id,
        "source_origin_time": str(origin.time),
        "source_origin_uncertainty_sec": origin.time_errors["uncertainty"],
        "source_latitude_deg": origin.latitude,
        "source_latitude_uncertainty_km": origin.latitude_errors["uncertainty"],
        "source_longitude_deg": origin.longitude,
        "source_longitude_uncertainty_km": origin.longitude_errors["uncertainty"],
        # Depths are given in metres; test against None so that a depth of 0 m is kept.
        "source_depth_km": origin.depth / 1e3 if origin.depth is not None else None,
        "source_depth_uncertainty_km": origin.depth_errors["uncertainty"] / 1e3
        if origin.depth_errors["uncertainty"] is not None else None,
    }

    if mag is not None:
        event_params["source_magnitude"] = mag.mag
        event_params["source_magnitude_uncertainty"] = mag.mag_errors["uncertainty"]
        event_params["source_magnitude_type"] = mag.magnitude_type
        event_params["source_magnitude_author"] = mag.creation_info.agency_id

    # Assign the split for every event, whether or not it has a magnitude.
    if str(origin.time) < "2015-01-07":
        split = "train"
    elif str(origin.time) < "2015-01-08":
        split = "dev"
    else:
        split = "test"
    event_params["split"] = split
    return event_params
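
The split is chosen by comparing the origin time, formatted as an ISO 8601 string, with date strings. This works because such timestamps sort chronologically when compared as plain strings. The sketch below isolates that logic; `assign_split` is a hypothetical helper, and the thresholds are the ones used in the cell above.

```python
def assign_split(origin_time, dev_start="2015-01-07", test_start="2015-01-08"):
    """Return 'train', 'dev' or 'test' for an ISO 8601 origin time string."""
    if origin_time < dev_start:
        return "train"
    if origin_time < test_start:
        return "dev"
    return "test"


for time_string in (
    "2015-01-06T23:59:59.000000Z",
    "2015-01-07T00:00:00.000000Z",
    "2015-01-08T12:00:00.000000Z",
):
    print(time_string, assign_split(time_string))
```

For real datasets the thresholds would be chosen so that each partition contains a sensible share of the events.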

#### Extracting the trace parameters

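From each pick, we extract parameters about the trace and store them in a dictionary. Again, we only extract very basic parameters: the network, station and location codes, and the first two characters of the channel code, which identify the band and instrument type without the component.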

In [None]:
def get_trace_params(pick):
    net = pick.waveform_id.network_code
    sta = pick.waveform_id.station_code

    trace_params = {
        "station_network_code": net,
        "station_code": sta,
        "trace_channel": pick.waveform_id.channel_code[:2],
        "station_location_code": pick.waveform_id.location_code,
    }

    return trace_params

#### Downloading the waveforms

As a last step, we need to access the waveforms. Here, we read them from local GCF files and map the sensor codes used in the file paths to the station names used in the catalog. Note that we cannot expect waveforms to be available for every pick, so we return an empty stream if the station cannot be resolved and skip any channel whose file cannot be read.

### Method 1

In [None]:
# def rename_stream_mapper(obspy_stream, networks, stations, locations, channels):
#     for tr in obspy_stream:
#         tr.stats.network = networks[tr.stats.network]
#         tr.stats.station = stations[tr.stats.station]
#         tr.stats.location = locations[tr.stats.location]
#         tr.stats.channel = channels[tr.stats.channel]

def reversing_dictionary(dictionary):
    return {v:k for k, v in dictionary.items()}

networks = {'': ''}
stations = {'6210': 'HANA',# 6210-Hana-shour  
            '6249': 'SORM',# 6249-Sarmak ??? SORM
            '6260': 'CHAH',# 6260-Chahgah     
            '6270': 'LAVA',# 6270-Lavar
            '6215': 'JASH',# 6215-Jashk       
            '6251': 'SANA',# 6251-Sana              
            '6261': 'SHON',# 6261-Shonbeh     
            '6289': 'ABDA',# 6289-Abdan
            '6218': 'KERD',# 6218-Kerdelan    
            '6252': 'BONY',# 6252-Bonyad            
            '6266': 'KARD',# 6266-Kardaneh
            '6219': 'DOHO',# 6219-Dohouk      
            '6255': 'BOBM',# 6255-Babmonir          
            '6268': 'ESLA',# 6268-Eslam-Abad
            '6226': 'BAGH',# 6226-Baghan      
            '6259': 'GENK',# 6259-Genkhak-Sheikhha  
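            # NOTE: '6270' is already used for 'LAVA' above; in a dict literal
            # the later entry silently replaces the earlier one.
            '6270': 'DARV',# 6270-drveshei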
            }

locations = {'': ''}

channels = {'HHZ': 'HHZ',
            'HHN': 'HHN',
            'HHE': 'HHE'}

stationsr = reversing_dictionary(stations)

### Method 2

In [None]:
import pandas as pd
from obspy import UTCDateTime as utc
import numpy as np

df = pd.read_csv(cfg.path.network_details, dtype=str)

def rename_stream_mapper(stream, dataframe):
    for tr in stream:
        df = dataframe[dataframe.eq(tr.stats.station).any(axis=1)]
        for indx, row in df.iterrows():
            cond1 = utc(row.start) <= tr.stats.starttime
            cond2 = tr.stats.endtime <= utc(row.end)
            if cond1 and cond2:
                tr.stats.network = row.network
                tr.stats.station = row.station
                tr.stats.location = row.location
                tr.stats.channel = row.channel + tr.stats.channel[-1]

In [None]:
def get_true_station_name(dataframe, sensor_or_station, time):
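    # Select all rows in which any column equals the given sensor or station code.
    # reset_index() returns a new frame, which avoids modifying a slice in place.
    df = dataframe[dataframe.eq(sensor_or_station).any(axis=1)].reset_index()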
    # print(df)
    if df.shape[0] == 1:
        target = df.iloc[0]
    elif df.shape[0] > 1:
        target = None
        for indx, row in df.iterrows():
            cond = utc(row.start) <= time <= utc(row.end)
            if cond:
                target = row
    else:
        print('No matching entry found in the network dataframe.')
        target = None
    return target
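
The lookup above resolves a sensor or station code to one row of the network table, using the operating period to disambiguate when a code matches several rows. The following plain-Python sketch shows the same idea with a list of dictionaries and `datetime` objects instead of a DataFrame and `UTCDateTime`. The helper `find_station`, the rows and all dates are hypothetical; unlike the function above, it returns the first row whose period contains the time rather than the last.

```python
from datetime import datetime

# Hypothetical network table: one sensor code reused by two stations.
rows = [
    {"sensor": "6270", "station": "LAVA",
     "start": datetime(2014, 1, 1), "end": datetime(2014, 12, 31)},
    {"sensor": "6270", "station": "DARV",
     "start": datetime(2015, 1, 1), "end": datetime(2015, 12, 31)},
    {"sensor": "6210", "station": "HANA",
     "start": datetime(2014, 1, 1), "end": datetime(2015, 12, 31)},
]


def find_station(rows, sensor_or_station, time):
    """Return the row matching a sensor or station code at the given time."""
    matches = [
        row for row in rows
        if sensor_or_station in (row["sensor"], row["station"])
    ]
    if len(matches) == 1:
        # A unique match is accepted without checking the operating period.
        return matches[0]
    for row in matches:
        if row["start"] <= time <= row["end"]:
            return row
    return None  # no match, or no match that was operating at this time


print(find_station(rows, "6270", datetime(2015, 3, 1))["station"])
```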

In [None]:
def get_waveforms(pick, root):
    station_details = get_true_station_name(
        dataframe=network_details,
        sensor_or_station=pick.waveform_id.station_code,
        time=pick.time
        )
    # stationsr = reversing_dictionary(stations)
    # station_code = stationsr.get(pick.waveform_id.station_code)
    if station_details is None:
        print(f'Station Not Found.\n{pick.waveform_id.station_code}\n')
        with open('reading-data-problem.txt', 'a') as f:
            f.write(f'Station Not Found.\n{pick.waveform_id.station_code}\n')
        return Stream()
    time = pick.time.strftime('%Y%m%d_%H')
    st = Stream()
    for channel in ['e', 'n', 'z']:
        path_data = f'{root}/{station_details.sensor}-*/gcf/{time}00{channel}.gcf'
        try:
            st += read(path_data)
        except Exception as error:
            print(f'Skip Data.\n{error}\n{path_data}\n')
            with open('reading-data-problem.txt', 'a') as f:
                f.write(f'Skip Data.\n{error}\n{path_data}\n')
    # st_trim = st.slice(starttime=pick.time-time_before,
    #                    endtime=pick.time+time_after)
    for tr in st:
        tr.stats.network = station_details.network
        tr.stats.station = station_details.station
        tr.stats.location = station_details.location
        tr.stats.channel = station_details.channel + tr.stats.channel[-1]

    return st
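
The file pattern built in `get_waveforms` suggests that the waveform files are stored per sensor, hour and component, with the file name starting at the full hour that contains the pick. The sketch below reproduces only the pattern construction, using `datetime` in place of `UTCDateTime` (both provide `strftime`). The helper `gcf_patterns`, the root directory and the pick time are hypothetical.

```python
from datetime import datetime


def gcf_patterns(root, sensor, pick_time):
    """Build one glob pattern per component for the hour containing pick_time."""
    hour_stamp = pick_time.strftime("%Y%m%d_%H")
    return [
        f"{root}/{sensor}-*/gcf/{hour_stamp}00{channel}.gcf"
        for channel in ("e", "n", "z")
    ]


patterns = gcf_patterns("/data", "6210", datetime(2015, 1, 7, 13, 45, 12))
for pattern in patterns:
    print(pattern)
```

Because only the hour containing the pick is read, a window that crosses an hour boundary is padded with zeros by `trim_data` rather than filled from the neighbouring file.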


In [None]:
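def trim_data(st, pick, before, after):
    # Cut the stream in place to a window around the pick.
    # Samples missing at either end are padded with zeros.
    st.trim(starttime=pick.time-before,
            endtime=pick.time+after,
            pad=True,
            nearest_sample=True,
            fill_value=0)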

In [None]:
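def preprocessing_data(st):
    # Clean-up merge: join directly adjacent traces without filling gaps
    st.merge(-1)
    # Remove the mean from each trace
    st.detrend('constant')
    # Merge the remaining traces, filling gaps with zeros
    st.merge(fill_value=0)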


#### Writing to SeisBench format

Now, we can combine all the above functions to write a dataset in SeisBench format. For this, we first need to define the path. For this example, we use the output directory given in the configuration. A dataset consists of two components:
 - a metadata file, always called `metadata.csv`, which contains all the associated properties of the waveform examples (e.g. trace parameters, source parameters etc.).
 - a waveforms file, always called `waveforms.hdf5`, containing the raw waveforms.

In [None]:
base_path = Path(cfg.path.out)
metadata_path = base_path / "metadata.csv"
waveforms_path = base_path / "waveforms.hdf5"
print(metadata_path, waveforms_path, sep='\n')

To write the dataset, we use the `WaveformDataWriter` provided by SeisBench. The writer should always be used as a context manager, i.e., using the `with` statement. This ensures that files are properly closed after writing and that teardown and cleanup operations are always called when exiting the context manager.

First, we need to set the data format for our dataset. We do this by assigning a dictionary to the `writer.data_format` attribute.

Next, we iterate over all events and, within each event, over all stations that have picks. Using the functions above, we generate the event and trace metadata and read the waveforms. We then convert the waveforms to a numpy array using the function `stream_to_array` provided in `seisbench.util`.

As a last step, we hand the event metadata and the waveforms as numpy array over to the writer using `add_trace`. The writer then automatically takes care of writing out the data in the correct format. It also takes care of performance optimisations that we outline in the further considerations below.

In [None]:
def select_picks(picks, station_name):
    picks = [pick for pick in picks
             if pick.waveform_id.station_code==station_name]
    picks = sorted(picks,
                   key= lambda p: p.time)
    return picks
    

In [None]:
def get_picks_time_difference(picks):
    picks_time = [pick.time for pick in picks]
    picks_time = sorted(picks_time)
    picks_difftime = [time-picks_time[0] for time in picks_time]
    return picks_difftime
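
The time differences are useful because each waveform is cut around the first pick of a station: the window ends `after` seconds after that pick, so any later pick whose offset reaches `after` falls outside the trace. A minimal sketch of that check, with the hypothetical helper `picks_outside_window` and made-up offsets in seconds:

```python
def picks_outside_window(pick_offsets, after):
    """Return the offsets (s after the first pick) at or beyond the window end."""
    return [offset for offset in pick_offsets if offset >= after]


# Hypothetical offsets relative to the first pick at a station
offsets = [0.0, 12.4, 75.0]
lost = picks_outside_window(offsets, after=60)
print(lost)
```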

In [None]:
def checking_equal_sps(stream):
    sps = stream[0].stats.sampling_rate
    assert all(tr.stats.sampling_rate == sps for tr in stream)

In [None]:
before = 60
after = 60

In [None]:
# Iterate over events and picks, write to SeisBench format
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:

    # Define data format
    writer.data_format = {
        "dimension_order": "CW",
        "component_order": "ZNE",
        "measurement": "velocity",
        "unit": "counts",
        "instrument_response": "not restituted",
    }
    n_all = len(catalog)
    for index, event in enumerate(catalog):
        # if index < 1050:
        #     continue
        if index % 100 == 0:
            print(f'{index} of {n_all} ({index/n_all*100:.2f}%)')
        # if index == 10:
        #     break

        event_params = get_event_params(event)
        stations_in_event = {pick.waveform_id.station_code for pick in event.picks}
        for station_name in stations_in_event:
            picks = select_picks(picks=event.picks,
                                 station_name=station_name)
            ###
            time_diff = get_picks_time_difference(picks)
            if max(time_diff) >= after:
                print(f'Losing pick: maximum pick time difference is {max(time_diff)} s')
            ###
            pick = picks[0]
            trace_params = get_trace_params(pick)
            waveforms = get_waveforms(pick, cfg.path.stream)
            ### Preprocessing waveform
            # rename_stream_mapper(waveforms, networks, stations, locations, channels)
            preprocessing_data(st=waveforms)
            # random = np.random.uniform(-before/2, before/2)
            trim_data(waveforms, pick, before=before, after=after)
            ### Check remaining data
            if len(waveforms) == 0:
                # No waveform data available
                continue
            ###
            sampling_rate = waveforms[0].stats.sampling_rate
            # Check that the traces have the same sampling rate
            checking_equal_sps(stream=waveforms)

            actual_t_start, data, _ = sbu.stream_to_array(
                waveforms,
                component_order=writer.data_format["component_order"],
            )
            trace_params["trace_sampling_rate_hz"] = sampling_rate
            trace_params["trace_start_time"] = str(actual_t_start)

            for pick in picks:
                sample = (pick.time - actual_t_start) * sampling_rate
                trace_params[f"trace_{pick.phase_hint}_arrival_sample"] = int(sample)
                trace_params[f"trace_{pick.phase_hint}_status"] = pick.evaluation_mode

            writer.add_trace({**event_params, **trace_params}, data)
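
The arrival samples written above follow from simple arithmetic: the offset of the pick from the start of the trace, in seconds, multiplied by the sampling rate and truncated to an integer. The sketch below assumes that the trace starts exactly `before` seconds ahead of the first pick (in practice `nearest_sample=True` may shift the start slightly). The helper `arrival_sample`, the sampling rate and the S-P time are hypothetical.

```python
def arrival_sample(pick_offset_s, sampling_rate_hz):
    """Convert a pick offset from the trace start (s) into a sample index."""
    return int(pick_offset_s * sampling_rate_hz)


before = 60            # seconds of data kept before the first pick
sampling_rate = 100.0  # hypothetical sampling rate in Hz

p_sample = arrival_sample(before, sampling_rate)        # first pick
s_sample = arrival_sample(before + 8.5, sampling_rate)  # pick 8.5 s later
print(p_sample, s_sample)
```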