# Importing Important modules

In [None]:
import ipynbname
import logging
import pkg_resources
import seisbench.data as sbd
import seisbench.util as sbu

from pathlib import Path
from obspy import read_events
from obspy import read
import pandas as pd
import os
from datetime import datetime
import sys

In [None]:
import warnings
# Keep the notebook output readable: DeprecationWarning messages are ignored
# from this point on.
warnings.simplefilter('ignore', category=DeprecationWarning)

In [None]:
# Candidate locations of the local SeisRoutine checkout. Every entry is put on
# sys.path so that the `import SeisRoutine...` statements below can resolve.
lib_path = [r'C:\Users\ikahbasi\OneDrive\Applications\GitHub\SeisRoutine',
            r'C:\Users\ikahb\OneDrive\Applications\GitHub\SeisRoutine']
for path in lib_path:
    # Re-running this cell must not keep growing sys.path with duplicates.
    if path not in sys.path:
        sys.path.append(path)
##########################################################################
import SeisRoutine.catalog as src
import SeisRoutine.waveform as srw
import SeisRoutine.config as srconf
import SeisRoutine.statistics as srs

In [None]:
from importlib import reload  # Python 3.4+
# Re-import the catalog and waveform helpers so that edits made to the
# SeisRoutine sources are picked up without restarting the kernel.
src, srw = reload(src), reload(srw)

# Define Some Functions

In [None]:
def getting_filename_and_path_of_the_running_code():
    """
    Get the filename and directory path of the currently executing code.

    This function works for both regular Python scripts (.py files) and Jupyter Notebooks
    (.ipynb files). For notebooks, it first looks for the ``__vsc_ipynb_file__``
    global (the VS Code case) and otherwise asks ``ipynbname`` for the notebook
    location.

    Returns:
        tuple: A tuple containing (directory_path, filename) of the running code.
        In every branch the first item is the directory and the second item is
        the file name including its extension.

    Note:
        In Jupyter Notebook environments, returns the notebook name and path.
        In regular Python scripts, returns the script name and path.
    """
    _file = sys.argv[0]
    if os.path.basename(_file) != "ipykernel_launcher.py":
        # Plain script: sys.argv[0] is the script itself.
        return os.path.dirname(_file), os.path.basename(_file)

    # Running inside a Jupyter kernel: sys.argv[0] is the kernel launcher, so
    # the notebook location has to come from somewhere else.
    vsc_file = globals().get('__vsc_ipynb_file__')
    if vsc_file is not None:
        return os.path.dirname(vsc_file), os.path.basename(vsc_file)

    # ipynbname.path() is the full path of the notebook *file*; split it so the
    # result has the same (directory, filename) shape as the branches above.
    nb_file = ipynbname.path()
    return str(nb_file.parent), nb_file.name

In [None]:
class get_data:
    """Load waveform files on demand and serve the traces of a picked station.

    The file to read is ``f'{root}/{pattern_path.format(time=...)}'``. The last
    file read is cached, so consecutive picks that map to the same file do not
    trigger another read.

    Attributes:
        root: Directory prefix prepended to the formatted pattern.
        pattern_path: ``str.format`` template that receives ``time``.
        stream: The currently loaded (pre-processed) stream, or None.
        stats: Unused here; kept for backward compatibility.
        stations: Station codes present in ``stream``.
    """

    def __init__(self, root, pattern_path):
        self.root = root
        self.pattern_path = pattern_path
        self.stream = None
        self.stats = None
        # Initialised here so the attribute exists before the first read().
        self.stations = []
        # Path of the file that `stream` was read from (None: nothing loaded).
        self._loaded_path = None

    def _build_path(self, time):
        """Return the waveform path that corresponds to ``time``."""
        pattern = self.pattern_path.format(time=time)
        return f'{self.root}/{pattern}'

    def read(self, time):
        """Read and pre-process the waveform file that corresponds to ``time``."""
        path = self._build_path(time)
        logging.info(f'Reading Data: {path}')
        self.stream = read(path)
        self.preprocessing_data()
        self.stations = list({tr.stats.station for tr in self.stream})
        # Remember the source only after the whole read succeeded.
        self._loaded_path = path

    def get_data_related_to_pick(self, pick):
        """Return the traces of the pick's station from the matching file.

        A read happens only when nothing is loaded yet or when the pick maps
        to a different file than the loaded one. Comparing paths (instead of
        the Julian day alone) also distinguishes equal days of different years
        and avoids re-reading the same file when a station is simply absent.
        The returned stream is empty when the station has no data in the file.
        """
        needs_read = (self.stream is None
                      or self._build_path(pick.time) != self._loaded_path)
        if needs_read:
            self.read(time=pick.time)
        target_stream = self.stream.select(station=pick.waveform_id.station_code)
        return target_stream

    def preprocessing_data(self):
        """Validate the sampling rate, then clean up, demean and merge traces."""
        self.sps_check()
        # merge(-1) is applied first, before demeaning each trace.
        self.stream.merge(-1)
        self.stream.detrend('constant')
        # Default merge: one trace per channel id.
        self.stream.merge()

    def sps_check(self, sps=100):
        """Raise ValueError unless every trace has sampling rate ``sps``."""
        found = {tr.stats.sampling_rate for tr in self.stream}
        if any(rate != sps for rate in found):
            raise ValueError(
                f'Expected sampling rate {sps} for all traces, '
                f'found: {sorted(found)}')

In [None]:
def get_source_params(event):
    """Collect the source (event) metadata of one event into a flat dict.

    Args:
        event: Event object exposing ``resource_id``, ``preferred_origin()``
            and ``preferred_magnitude()``.

    Returns:
        dict: ``source_*`` entries for the preferred origin. When a preferred
        magnitude exists, the ``source_magnitude*`` entries and ``split``
        (set to None) are added as well.
    """
    origin = event.preferred_origin()
    mag = event.preferred_magnitude()

    source_id = str(event.resource_id)

    event_params = {
        "source_id": source_id,
        "source_origin_time": str(origin.time),
        "source_origin_uncertainty_sec": origin.time_errors["uncertainty"],
        "source_latitude_deg": origin.latitude,
        "source_latitude_uncertainty_km": origin.latitude_errors["uncertainty"],
        "source_longitude_deg": origin.longitude,
        "source_longitude_uncertainty_km": origin.longitude_errors["uncertainty"],
        "source_depth_km": origin.depth,
        "source_depth_uncertainty_km": origin.depth_errors["uncertainty"],
    }
    ### Unit conversion
    # Depth values are divided by 1e3 (presumably metres -> kilometres).
    for key in ("source_depth_km", "source_depth_uncertainty_km"):
        if event_params[key] is not None:
            event_params[key] /= 1e3
    # Horizontal uncertainties are scaled by 111 (presumably degrees ->
    # kilometres). Each key is converted exactly once; previously the latitude
    # value was scaled twice and the longitude value was never converted.
    for key in ("source_latitude_uncertainty_km",
                "source_longitude_uncertainty_km"):
        if event_params[key] is not None:
            event_params[key] *= 111

    if mag is not None:
        # creation_info may be absent; do not fail on a missing author.
        creation_info = mag.creation_info
        event_params["source_magnitude"] = mag.mag
        event_params["source_magnitude_uncertainty"] = mag.mag_errors["uncertainty"]
        event_params["source_magnitude_type"] = mag.magnitude_type
        event_params["source_magnitude_author"] = (
            creation_info.agency_id if creation_info is not None else None)
        event_params["split"] = None
    return event_params

In [None]:
def make_station_parameters_of_network_details(df):
    """Turn the network-details table into a per-station metadata lookup.

    Args:
        df: DataFrame with the columns ``station``, ``network``, ``location``,
            ``latitude``, ``longitude``, ``elevation``, ``sensor`` and
            ``region``.

    Returns:
        dict: Maps each station code to its ``station_*`` metadata dict. If a
        station code occurs in several rows, the last row wins.
    """
    def _station_entry(record):
        # 'station_sensitivity_counts_spm' is always written as None.
        return {
            'station_code': record.station,
            'station_network_code': record.network,
            'station_location_code': record.location,
            'station_latitude_deg': record.latitude,
            'station_longitude_deg': record.longitude,
            'station_elevation_m': record.elevation,
            'station_sensitivity_counts_spm': None,
            'station_sensor': record.sensor,
            'station_region': record.region,
        }

    return {record.station: _station_entry(record)
            for _, record in df.iterrows()}



In [None]:
def get_trace_params(pick):
    """Extract the basic trace metadata carried by a pick.

    Args:
        pick: Pick object with ``waveform_id`` and ``evaluation_mode``.

    Returns:
        dict: Network, station and location codes, the first two characters
        of the channel code, and the evaluation mode of the pick.
    """
    waveform_id = pick.waveform_id
    return {
        "station_network_code": waveform_id.network_code,
        "station_code": waveform_id.station_code,
        # Only the first two characters are kept, i.e. the trailing
        # character(s) of the channel code are dropped.
        "trace_channel": waveform_id.channel_code[:2],
        "station_location_code": waveform_id.location_code,
        "evaluation_mode": pick.evaluation_mode,
    }

In [None]:
def get_phase_params(pick, event):
    """Return the attributes of the arrival linked to ``pick``.

    The arrival is looked up among the arrivals of the event's preferred
    origin. A fixed set of attributes is removed and every remaining key is
    suffixed with ``_<phase_hint>`` of the pick.

    Args:
        pick: Pick whose related arrival is wanted.
        event: Event that owns the pick.

    Returns:
        dict: Suffixed arrival attributes, or an empty dict when the lookup
        yields a falsy result.
    """
    origin = event.preferred_origin()
    arrival = src.select_arrival_related_to_the_pick(pick=pick,
                                                     arrivals=origin.arrivals)
    if not arrival:
        return {}

    attributes = arrival.__dict__.copy()
    dropped = ('resource_id', 'pick_id', 'phase', 'takeoff_angle_errors',
               'horizontal_slowness_residual',
               'horizontal_slowness_weight')
    for name in dropped:
        # Like pop() without a default, this raises KeyError if absent.
        del attributes[name]
    return {f'{name}_{pick.phase_hint}': value
            for name, value in attributes.items()}

In [None]:
import numpy as np
from scipy.stats import skew, zscore

In [None]:
from statistics import median

In [None]:
def stream_to_array_ikahbasi(stream, component_order):
    """
    Converts stream of single station waveforms into a numpy array according to a given component order.
    If trace start and end times disagree between component traces, remaining parts are filled with zeros.
    Masked samples of a trace (e.g. gaps left by merging) are left at zero as well.
    Also returns completeness, i.e., the fraction of samples in the output that actually contain data:
    a sample counts as data when it was copied from an unmasked trace sample and is not NaN.
    Assumes all traces to have the same sampling rate.

    :param stream: Stream to convert
    :type stream: obspy.Stream
    :param component_order: Component order
    :type component_order: str
    :return: starttime, data, completeness
    :rtype: UTCDateTime, np.ndarray, float
    """
    starttime = min(trace.stats.starttime for trace in stream)
    endtime = max(trace.stats.endtime for trace in stream)
    sampling_rate = stream[0].stats.sampling_rate

    # round() instead of truncation: a float error such as 28.999999 samples
    # must not drop the last sample or shift a trace by one sample.
    samples = int(round((endtime - starttime) * sampling_rate)) + 1

    data = np.zeros((len(component_order), samples), dtype="float64")
    # Marks the output samples that received a real (unmasked) input sample.
    has_data = np.zeros(data.shape, dtype=bool)
    for c_idx, c in enumerate(component_order):
        for trace in stream.select(channel=f"??{c}"):
            start_sample = int(round(
                (trace.stats.starttime - starttime) * sampling_rate))
            n_copy = min(len(trace.data), samples - start_sample)
            if n_copy <= 0:
                continue
            # Works for plain and masked arrays alike: for a plain array the
            # mask is all False, so every sample is copied.
            values = np.ma.getdata(trace.data)[:n_copy]
            valid = ~np.ma.getmaskarray(trace.data)[:n_copy]
            window = slice(start_sample, start_sample + n_copy)
            # Basic slicing yields a view, so this writes into `data`.
            data[c_idx, window][valid] = values[valid]
            has_data[c_idx, window] |= valid

    if data.size:
        n_good = np.count_nonzero(has_data & ~np.isnan(data))
        completeness = n_good / data.size
    else:
        completeness = 0.0
    return starttime, data, completeness

# Initializing the init file and starting logging.

In [None]:
# The init file only says where the real configuration lives.
init_cfg = srconf.load_config('0-init-cfg.yml')
cfg_path = os.path.join(init_cfg.target_config_filepath, init_cfg.target_config_filename)
cfg = srconf.load_config(cfg_path)
# Fill the `datetime_str` placeholder of the dataset output path with the
# start time of this run.
today_str = datetime.today().strftime('%Y-%m-%dT%H-%M-%S')
_dataset_template = cfg.mk_dataset.path.outputs.dataset
cfg.mk_dataset.path.outputs.dataset = _dataset_template.format(datetime_str=today_str)

In [None]:
# Rebuild the path of the active configuration file from the init file.
# The bare `cfg_path` on the last line makes the notebook display the value.
cfg_path = os.path.join(init_cfg.target_config_filepath,
                        init_cfg.target_config_filename)
cfg_path

In [None]:
# Dump the configuration used for this run next to the original file.
# Only the file name is rewritten: running str.replace over the whole path
# would also alter any directory whose name contains 'cfg'.
_cfg_dir, _cfg_name = os.path.split(cfg_path)
_last_run_name = _cfg_name.replace('cfg', 'last_run-cfg')
if _last_run_name == _cfg_name:
    # No 'cfg' in the file name: never overwrite the source configuration.
    _last_run_name = f'last_run-{_cfg_name}'
with open(os.path.join(_cfg_dir, _last_run_name), 'w') as file:
    cfg.to_yaml(stream=file, default_flow_style=False, indent=4)

In [None]:
# Collect the logging options from the configuration first; the log file is
# placed in the dataset output directory.
_logging_options = {
    'level': cfg.log.level,
    'log_format': cfg.log.format,
    'mode': cfg.log.mode,
    'colored_console': True,
    'filepath': cfg.mk_dataset.path.outputs.dataset,
    'filename_prefix': cfg.log.filename_prefix,
    'filename': cfg.mk_dataset.path.outputs.log.filename,
}
srconf.configure_logging(**_logging_options)

In [None]:
# 80 plus signs, logged between the sections of the run log.
log_separator = 80 * "+"

In [None]:
# First log entry: which notebook/script produced this log and where it is.
nb_path, nb_name = getting_filename_and_path_of_the_running_code()
msg = "".join((
    f"Logging has started for notebook: {nb_name}.\n",
    f"This file is located at: {nb_path}\n",
))
logging.info(msg)
logging.info(f"Separator: {log_separator}")

In [None]:
# Log every installed package with its version, sorted by package key.
# `imported_modules` holds the global names that are also loaded modules; it
# is not used to filter the list below.
imported_modules = {name.split('.')[0] for name in globals() if name in sys.modules}
installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
msg = "Packages List:\n" + "".join(
    f"{package}=={installed_packages[package]}\n"
    for package in sorted(installed_packages))
logging.info(msg)
logging.info(f"Separator: {log_separator}")

In [None]:
# Record the complete configuration of this run in the log.
msg = str(cfg)
logging.info(f'Configuration File:\n{msg}')
logging.info(f"Separator: {log_separator}")

# Loading Seismic Catalog and network details.

In [None]:
# Read the catalog and keep, as a plain list, only the events that have picks.
catalog = read_events(cfg.mk_dataset.path.inputs.catalog)
catalog = list(filter(lambda ev: ev.picks != [], catalog))

In [None]:
# Every column is read as text and missing cells become empty strings.
network_details = pd.read_csv(
    cfg.mk_dataset.path.inputs.network_details, dtype=str,
).fillna(value='')
network_parameters = make_station_parameters_of_network_details(df=network_details)
# Station codes described in the network-details file.
stations_network_list = network_parameters.keys()

In [None]:
### Just for Ahar
d = {'SDHR': {}, 'JIGH': {}}
for ev in catalog:
    for pick in ev.picks:
        if pick.waveform_id.station_code in ('SDHR', 'JIGH'):
            if pick.time.julday in d[pick.waveform_id.station_code].keys():
                d[pick.waveform_id.station_code][pick.time.julday] += 1
            else:
                d[pick.waveform_id.station_code][pick.time.julday] = 1
            # d[pick.waveform_id.station_code].append(pick.time.julday)

In [None]:
# otime = [ev.preferred_origin().time.timestamp for ev in catalog]
# import matplotlib.pyplot as plt
# _ = plt.hist(otime)

In [None]:
# Presumably prints how often each phase label occurs in the catalog, with
# upper/lower case treated as equal -- confirm against SeisRoutine.catalog.
src.print_phase_frequency(catalog, case_sensitivity=False)

#### Extracting the event parameters

From the catalog, we extract the event parameters and store them into a dictionary. Here, we only extract a few basic parameters on the source and its magnitude - if available. In addition, we define the split of the dataset into training/development/test partitions. We visualize one example.

#### Extracting the trace parameters

From each pick, we extract parameters about the trace and store them in a dictionary. Again, we only extract very basic parameters. We visualize one example.

In [None]:
def do_or_dont_have_related_arrival(catalog):
    """Count the picks with and without a related arrival.

    For every pick of every event, the related arrival is looked up among the
    arrivals of the event's preferred origin. A lookup result equal to
    ``False`` counts as "no related arrival"; anything else counts as found.

    Args:
        catalog: Iterable of events.

    Returns:
        tuple: ``(do, dont)`` -- number of picks with / without an arrival.
    """
    counts = [0, 0]  # [with arrival, without arrival]
    for event in catalog:
        preferred = event.preferred_origin()
        for pick in event.picks:
            related = src.select_arrival_related_to_the_pick(
                pick=pick, arrivals=preferred.arrivals)
            counts[1 if related == False else 0] += 1
    return counts[0], counts[1]
    
# Log how many picks can be linked to an arrival of the preferred origin.
do, dont = do_or_dont_have_related_arrival(catalog)
msg = "".join((
    "Pick arrival situation:\n",
    f"{do} Picks have related arrivals.\n",
    f"{dont} Picks don't have related arrivals.",
))
logging.info(msg)
logging.info(f"Separator: {log_separator}")

#### Writing to SeisBench format

Now, we can combine all the above functions together to write a dataset in SeisBench format. For this, we first need to define the path. For this example, we are using the current working directory. A dataset consists of 2 components:
 - a metadata file, always called `metadata.csv`, which contains all the associated properties of the waveform examples (e.g. trace parameters, source parameters etc.).
 - a waveforms file, always called `waveforms.hdf5`, containing the raw waveforms.

To write the dataset, we use the `WaveformDataWriter` provided by SeisBench. The writer should always be used as a context manager, i.e., using the `with` statement, as shown below. This ensures that files are properly closed after writing and that teardown and cleanup operations are always called when exiting the context manager.

First, we need to set the data format for our dataset. We do this by assigning a dictionary to the `writer.data_format` group.

Next, we iterate over all events and over all picks in each event. Using the functions above, we generate the event and trace metadata and read the waveforms. We then convert the waveforms to a numpy array using the function `stream_to_array_ikahbasi` defined above, a local variant of `stream_to_array` from `seisbench.util`.

As a last step, we hand the event metadata and the waveforms as numpy array over to the writer using `add_trace`. The writer then automatically takes care of writing out the data in the correct format. It also takes care of performance optimisations that we outline in the further considerations below.

In [None]:
# Waveform reader: builds the file path from the root directory and the
# time-dependent pattern given in the configuration.
get_stream = get_data(
    root=cfg.mk_dataset.path.inputs.stream_root,
    pattern_path=cfg.mk_dataset.path.inputs.stream_pattern,
)

In [None]:
# Output locations: the two SeisBench files live directly in the dataset
# directory.
base_path = Path(cfg.mk_dataset.path.outputs.dataset)
metadata_path, waveforms_path = (
    base_path / file_name for file_name in ("metadata.csv", "waveforms.hdf5"))
###
msg = "Dataset will be save at:\n" + '\n'.join(
    map(str, (metadata_path, waveforms_path)))
###
if cfg.mk_dataset.save_streams:
    # Optional: the cut streams are additionally stored in an `mseed` folder.
    stream_path = base_path / "mseed"
    stream_path.mkdir(parents=True, exist_ok=True)
    msg = '\n'.join((msg, str(stream_path)))

logging.info(msg)
logging.info(f"Separator: {log_separator}")

In [None]:
# Iterate over events and picks, write to SeisBench format
def _percent(part, total):
    """Return ``part / total`` as a percentage, or 0.0 when ``total`` is 0."""
    return part / total * 100 if total else 0.0


def _progress_message(n_events, n_events_total, n_picks, n_picks_total):
    """Build the progress report that is logged during and after the loop."""
    return ("Passed Processes:\n"
            f"{n_events} of {n_events_total} events passed "
            f"({_percent(n_events, n_events_total):.2f}%).\n"
            f"{n_picks} of {n_picks_total} picks passed "
            f"({_percent(n_picks, n_picks_total):.2f}%).")


n_passed_picks = 0
# Defined before the loop so that the final report also works for an empty
# catalog.
n_passed_events = 0
n_all_picks = do
n_all_events = len(catalog)
n_events_step = 100
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:

    # Define data format
    writer.data_format = {
        "dimension_order":     cfg.mk_dataset.parameters.dimension_order,
        "component_order":     cfg.mk_dataset.parameters.component_order,
        "sampling_rate":       cfg.mk_dataset.parameters.sampling_rate,
        "measurement":         cfg.mk_dataset.parameters.measurement,
        "unit":                cfg.mk_dataset.parameters.unit,
        "instrument_response": cfg.mk_dataset.parameters.instrument_response,
    }
    for event_index, event in enumerate(catalog):
        origin = event.preferred_origin()
        ########################################################################
        ## Selecting stations to processing
        ## Option 1: using all stations that exist in the event catalog.
        ##           (variable: stations_event_list)
        ## Option 2: using just stations that exist in the network details file.
        ##           (variable: stations_network_list)
        ########################################################################
        stations_event_list = {pick.waveform_id.station_code
                               for pick in event.picks}
        # Option 2 is active: station_params below needs the network details.
        _stations = stations_network_list
        for station_name in _stations:
            picks = src.select_picks(picks=event.picks,
                                     station_name=station_name)
            if picks == []:
                continue
            ###
            # The cut window is defined relative to the first listed pick.
            pick = picks[0]
            stime = pick.time - cfg.mk_dataset.cut_time.before
            etime = pick.time + cfg.mk_dataset.cut_time.after
            ### Reading Data
            st = get_stream.get_data_related_to_pick(pick=pick)
            st = st.slice(starttime=stime,
                          endtime=etime,
                          nearest_sample=True)
            # It's possible that all data were masked! If not split,
            # N empty traces exist and len(st) shows N.
            st = st.split()
            ### Check remaining data
            if len(st) == 0:
                msg = ('There is No WaveForms After Slicing!!!\n'
                       f'Station: {station_name}\n'
                       f'Otime:   {origin.time}\n'
                       f'Pick:    {pick.time}')
                logging.warning(msg)
                continue

            st.detrend('constant')
            st.merge()
            ### Check that the traces have the same sampling rate
            srw.waveform.uni_sps(st=st, )
            # starttime, data, completeness = sbu.stream_to_array(
            #     stream=st,
            #     component_order=cfg.mk_dataset.parameters.component_order)
            starttime, data, completeness = stream_to_array_ikahbasi(
                stream=st,
                component_order=cfg.mk_dataset.parameters.component_order)

            ########################
            ### Trace parameters ###
            ########################
            trace_params = {}
            tr = st[0]
            sps = tr.stats.sampling_rate
            # trace_params['trace_name'] = 'Kaki'
            trace_params["trace_start_time"] = str(starttime)
            trace_params["trace_npts"] = data.shape[-1]
            trace_params["trace_sampling_rate_hz"] = sps
            trace_params["trace_dt_s"] = tr.stats.delta
            trace_params["trace_channel"] = tr.stats.channel[:2]
            trace_params["trace_category"] = "earthquake"
            trace_params["trace_completeness"] = completeness
            component_order = cfg.mk_dataset.parameters.component_order
            ### Signal-to-noise ratio around the earliest pick of the station
            pick = sorted(picks, key=lambda x: x.time)[0]
            sample = (pick.time - starttime) * sps
            sample = int(round(sample))
            snr = srw.health_check.routine.compute_snr(
                data=data, pick_idx=sample,
                noise_window=100, signal_window=200, axis=1, domain='time')
            # Keep log10 away from zero.
            epsilon = 1e-10
            snr[snr<epsilon] += epsilon
            tmp = {'snr': snr,
                   'snr_db': 20 * np.log10(snr)}
            for key, vals in tmp.items():
                for cha, val in zip(component_order, vals):
                    trace_params[f"trace_{cha}_{key}"] = val
            ### Per-component amplitude statistics
            tmp = {'median': np.median(data, axis=1),
                   'mean': np.mean(data, axis=1),
                   'rms': np.sqrt(np.mean(np.power(data, 2), axis=1)),
                   'max': np.max(data, axis=1),
                   'min': np.min(data, axis=1),
                   'lower_quartile': [np.percentile(_, 25) for _ in data],
                   'upper_quartile': [np.percentile(_, 75) for _ in data],
                   'gap': np.sum(np.isnan(data), axis=1)}
            for key, vals in tmp.items():
                for cha, val in zip(component_order, vals):
                    trace_params[f"trace_{cha}_{key}_counts"] = val
            ### Note: I changed the 'trace_*_arrival_sample' keywords.
            for pick in picks:
                sample = (pick.time - starttime) * sps
                sample = int(round(sample))
                hint = pick.phase_hint
                hint = hint if hint!='' else None
                trace_params[f"trace_{hint}_arrival_sample"] = sample
                trace_params[f"trace_{hint}_status"] = pick.evaluation_mode

            #########################
            ### Source parameters ###
            #########################
            source_params = get_source_params(event)

            ##########################
            ### Station parameters ###
            ##########################
            station_params = network_parameters[station_name]

            #######################
            ### Path parameters ###
            #######################
            path_params = {}

            for pick in picks:
                arrival = src.select_arrival_related_to_the_pick(
                    pick=pick, arrivals=origin.arrivals)
                if arrival:
                    hint = pick.phase_hint
                    hint = hint if hint!='' else None
                    path_params[f'path_{hint}_travel_s'] = pick.time - origin.time
                    path_params[f'path_{hint}_residual_s'] = arrival.time_residual
                    #
                    azimuth = arrival.azimuth
                    # `is not None`: an azimuth of exactly 0 degrees is valid.
                    if azimuth is not None:
                        # The back azimuth points the opposite way and must
                        # stay within [0, 360): +180 below 180, -180 otherwise.
                        path_params['path_azimuth_deg'] = azimuth
                        path_params['path_back_azimuth_deg'] = (azimuth + 180) % 360
                    if arrival.distance is not None:
                        # Scaled by 111 (presumably degrees -> kilometres).
                        path_params['path_ep_distance_km'] = arrival.distance * 111

            ############################
            ### Write In The DataSet ###
            ############################
            writer.add_trace(waveform=data,
                             metadata={**trace_params,
                                       **source_params,
                                       **station_params,
                                       **path_params})

            #####################
            ### Saving stream ###
            #####################
            if cfg.mk_dataset.save_streams:
                # The origin time must be turned into text before the string
                # clean-up; it is not a str itself.
                otime = str(origin.time)
                otime = otime.replace('-', '').replace(':', '')[:-5]
                stream_name = f'{event_index}_{otime}_{station_name}.msd'
                st.write(str(stream_path / stream_name), format='MSEED')
            n_passed_picks += len(picks)
        n_passed_events = event_index + 1
        ### Write log
        if n_passed_events % n_events_step == 0:
            msg = _progress_message(n_passed_events, n_all_events,
                                    n_passed_picks, n_all_picks)
            logging.info(msg)

# Final report. NOTE(review): n_all_picks counts picks that have a related
# arrival, while n_passed_picks counts all picks of the processed stations,
# so the pick percentage is not bounded by 100 -- confirm the intended total.
msg = _progress_message(n_passed_events, n_all_events,
                        n_passed_picks, n_all_picks)
logging.info(msg)