In [None]:
import numpy as np
from obspy import Stream, Trace, UTCDateTime, read_inventory
from obspy.core.trace import Stats
from obspy.signal.trigger import recursive_sta_lta, trigger_onset
from obspy.core.inventory import Inventory, Network, Station, Channel, Response
from typing import Dict, Any, Union, List

def merge_responses(sensor_response, datalogger_response):
    """
    Build a Response whose stage list is the sensor stages followed by the datalogger stages.
    Parameters:
    - sensor_response: Object exposing a `response_stages` list (sensor side).
    - datalogger_response: Object exposing a `response_stages` list (datalogger side).
    Returns:
    - A new Response constructed only from the concatenated stage list; no other
      attribute of either input is passed to the constructor.
    """
    # NOTE(review): the stage objects are reused unchanged, so whatever numbering or
    # gain information they carry is not adjusted here - confirm this is intended.
    combined_stages = list(sensor_response.response_stages)
    combined_stages.extend(datalogger_response.response_stages)
    return Response(response_stages=combined_stages)

def adjust_station_name(sta, index):
    """
    Derive a numbered station code that is never longer than five characters.
    Parameters:
    - sta: Base station code.
    - index: Number appended to the base code.
    Returns:
    - The base code (shortened as far as necessary) followed by the index.
    Raises:
    - ValueError: If the index alone needs more than five characters.
    """
    max_length = 5
    suffix = str(index)
    if len(suffix) > max_length:
        raise ValueError(f"index {index} does not fit into a {max_length}-character station code")
    # A base longer than four characters always gives up its last character for the index
    sta_base = sta[:-1] if len(sta) > 4 else sta
    # Keep only as much of the base as still fits in front of the index
    return f"{sta_base[:max_length - len(suffix)]}{suffix}"

def save_seismogram(gen_data: Dict[str, Any], magnitude: float, net: Union[str, List[str]],
                    sta: Union[str, List[str]], depth: float = 0.0, sampling_rate: float = 100.0,
                    latitude: Union[float, List[float]] = 0.0, longitude: Union[float, List[float]] = 0.0,
                    num_samples: int = 100, sensor_file: str = None, datalogger_file: str = None) -> None:
    """
    Save seismogram data with station metadata to MiniSEED and STATIONXML files.
    The waveforms are processed in consecutive groups of `num_samples` rows. Each group
    becomes one network entry, and every row inside it becomes one station entry with one
    channel (and one trace) per waveform component.
    Parameters:
    - gen_data: Dict providing 'waveforms' (indexed as [row, component, time]) and the per-row
      entries 'hypocentral_distance', 'is_shallow_crustal' and 'vs30'.
    - magnitude: Magnitude of the event (only used in the output file name).
    - net: Network code (single value, or one value per group).
    - sta: Station code (single value that gets numbered per group, or one value per group).
    - depth: Depth of the event in kilometers.
    - sampling_rate: Sampling rate of the seismogram in Hz.
    - latitude: Latitude written to the station and channel entries (single value, or one
      value per group).
    - longitude: Longitude written to the station and channel entries (single value, or one
      value per group).
    - num_samples: Number of consecutive waveform rows that form one group.
    - sensor_file: Path to the sensor XML file containing the response information.
    - datalogger_file: Path to the datalogger XML file containing the response information.
      A response is attached to the channels only when both files are given.
    Raises:
    - ValueError: If `num_samples` is not positive, does not divide the number of waveform
      rows, or if the waveforms have more than three components.
    """
    import os  # local import: only needed to create the output directory

    # STA/LTA thresholds used to locate the waveform onset
    trigger_on = 1.5
    trigger_off = 1.0
    vp = 6.0  # km/s, velocity used for the travel-time estimate
    # NOTE(review): these codes assume the component axis is ordered N, E, Z - confirm that
    # this matches the component ordering of gen_data['waveforms'].
    channel_codes = ('BHN', 'BHE', 'BHZ')
    origin_time = UTCDateTime()
    # station was deployed exactly a month before the earthquake
    deployment_time = origin_time - 3600*24*30

    waveforms = gen_data['waveforms']
    total_waveforms = waveforms.shape[0]
    num_components = waveforms.shape[1]
    if num_samples < 1:
        raise ValueError("num_samples must be a positive integer")
    if total_waveforms % num_samples != 0:
        raise ValueError(
            f"the number of waveforms ({total_waveforms}) is not a multiple of num_samples ({num_samples})"
        )
    if num_components > len(channel_codes):
        raise ValueError(f"at most {len(channel_codes)} waveform components are supported")
    # Rows are consumed in blocks of num_samples, so this is the number of blocks
    num_groups = total_waveforms // num_samples

    # Handle latitude and longitude input (scalars, including NumPy scalars, are broadcast)
    latitudes = (np.full(num_groups, latitude, dtype=float) if np.ndim(latitude) == 0
                 else np.asarray(latitude, dtype=float))
    longitudes = (np.full(num_groups, longitude, dtype=float) if np.ndim(longitude) == 0
                  else np.asarray(longitude, dtype=float))
    if isinstance(net, str):
        net = [net] * num_groups
    if isinstance(sta, str):
        sta = [adjust_station_name(sta, i + 1) for i in range(num_groups)]

    # Read the responses from the sensor and datalogger XML files once, up front
    sensor_response = None
    datalogger_response = None
    if sensor_file and datalogger_file:
        sensor_response = read_inventory(sensor_file)[0][0][0].response
        datalogger_response = read_inventory(datalogger_file)[0][0][0].response

    stream = Stream()
    inventory = Inventory(networks=[], source="HighFEM-user")
    for it in range(num_groups):
        network = Network(
            code=net[it],
            stations=[],
            description="Station This Quake Does Not Exist network",
            start_date=deployment_time
        )
        for i in range(it*num_samples, it*num_samples + num_samples):
            distance = gen_data['hypocentral_distance'][i]
            vs30 = gen_data['vs30'][i]
            # Station inventory. The distance is stored in the elevation field and Vs30 in
            # restricted_status so that both can be read back from the StationXML file.
            station = Station(
                code=sta[it],
                creation_date=deployment_time,
                latitude=latitudes[it],
                longitude=longitudes[it],
                elevation=distance,
                total_number_of_channels=3,
                description=(
                    f"Shallow Crustal: {gen_data['is_shallow_crustal'][i]}, "  # is shallow? 0 for not, 1 for yes
                    f"Vs30: {gen_data['vs30'][i]} m/s, "  # vs30
                    "Elevation is hypocentral distance here"  # add unit information
                ),
                restricted_status=str(vs30)
            )
            # Travel-time estimate; identical for all components of this row.
            # NOTE(review): depth is combined with a distance that is labelled hypocentral -
            # confirm whether the stored distance is in fact epicentral.
            est_distance = np.sqrt(distance**2 + depth**2)
            time_init = est_distance / vp
            for j in range(num_components):
                component = waveforms[i, j, :]
                # Define the onset of waveform from the recursive STA/LTA characteristic function
                sta_lta_ratio = recursive_sta_lta(
                    component,
                    int(1.5 * sampling_rate),
                    int(4 * sampling_rate)
                )
                triggers = trigger_onset(sta_lta_ratio, trigger_on, trigger_off)
                onset = np.amin(triggers) if len(triggers) > 0 else 0
                time_onset = onset / sampling_rate
                # Stats
                stats = Stats()
                stats.network = net[it]
                stats.station = sta[it]
                stats.location = '00'
                stats.calib = 1.0  # numeric calibration factor (a string here breaks arithmetic on it)
                stats.channel = channel_codes[j]
                stats.sampling_rate = sampling_rate
                stats.npts = component.shape[0]
                # Shift the start so that the detected onset falls on the estimated arrival time
                stats.starttime = origin_time + time_init - time_onset
                stats['units'] = 'm/s^2'
                # Trace data
                stream.append(Trace(data=component, header=stats))
                # Channel inventory
                channel = Channel(
                    code=stats.channel,
                    location_code="00",
                    sample_rate=sampling_rate,
                    latitude=latitudes[it],
                    longitude=longitudes[it],
                    elevation=distance,
                    depth=0.0,
                    description="Amplitude in m/s^2"
                )
                # Add the merged response to the channel
                if sensor_response is not None and datalogger_response is not None:
                    channel.response = merge_responses(sensor_response, datalogger_response)
                station.channels.append(channel)
            network.stations.append(station)
        inventory.networks.append(network)
    # Save everything
    output_file = f"data_gm0/HighFEM_tqdne_{magnitude}_Japan"
    # Make sure the target directory exists before writing
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    stream.write(output_file + ".mseed", format='MSEED')
    inventory.write(output_file + ".xml", format='STATIONXML')

In [None]:
# Example usage: write the generated waveforms together with the response files below.
# NOTE: `files`, `mag`, `depth`, `lat` and `lon` are not defined in this cell -
# presumably they come from earlier cells; confirm before running.
sensor_file = 'EEW/GenericUnitySensor.xml'
datalogger_file = 'EEW/GenericUnityDataLogger.xml'
save_seismogram(
    files,
    magnitude=mag,
    net='HF',
    sta='HFEM',
    depth=depth,
    sampling_rate=100,
    latitude=lat,
    longitude=lon,
    num_samples=1,
    sensor_file=sensor_file,
    datalogger_file=datalogger_file,
)

In [None]:
import matplotlib.pyplot as plt
from obspy import read, read_inventory
import numpy as np
import matplotlib.colors as mcolors
from matplotlib import cm
from matplotlib.colorbar import ColorbarBase
# Global figure appearance: serif font family at size 16
plt.rcParams.update({'font.family': 'serif', 'font.size': 16})
# INPUT
# Waveforms (MiniSEED) and the matching station metadata (StationXML)
traces = read("data_gm0/HighFEM_tqdne_5.232177479920197_Japan.mseed")
inv = read_inventory("data_gm0/HighFEM_tqdne_5.232177479920197_Japan.xml")
scale = 50  # factor applied to the trace amplitudes before plotting
ntraces = 25  # number of traces requested for the section plot
# Colormap
cmap = cm.plasma
def select_maximally_spaced_values(values, n):
    """
    Return `n` indices spread as evenly as possible over the positions of `values`.
    The selection depends only on the length of `values`, not on its contents. For n >= 2
    the first and the last position are always included, and all indices are distinct,
    non-negative and in increasing order.
    Parameters:
    - values: Sequence to pick from.
    - n: Number of indices to return.
    Returns:
    - List of `n` integer indices into `values`.
    Raises:
    - ValueError: If `n` is negative or larger than the number of elements.
    """
    count = len(values)
    if n > count:
        raise ValueError("n cannot be greater than the number of elements in the list")
    if n < 0:
        raise ValueError("n cannot be negative")
    if n == 0:
        return []
    if n == 1:
        return [0]
    last = count - 1
    gaps = n - 1
    # Round i * last / gaps to the nearest integer using integer arithmetic only, so the
    # result always ends exactly on the last index and never repeats an index.
    return [(i * last + gaps // 2) // gaps for i in range(n)]
import os  # needed below to create the figure directory

for tr in traces:
    # Overwrite the network code before the metadata lookup
    # (presumably so that tr.id matches the channel ids in the StationXML - confirm)
    tr.stats.network = "HF"
    # The source-to-station distance is read from the channel "elevation" field
    # (in km, given the conversion factor below)
    dist = inv.get_channel_metadata(tr.id)["elevation"]
    # Store the distance in meters
    tr.stats.distance = 1000. * dist
    # Vs30 is read back from the station's restricted_status field
    tr.stats.vs30 = float(inv.select(station=tr.stats.station)[0][0].restricted_status)
# Trim all to one common time window, zero-padding where a trace has no data
traces.trim(starttime=min(tr.stats.starttime for tr in traces), endtime=max(tr.stats.endtime for tr in traces), pad=True, fill_value=0.)
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111)
# Fixed color scale for the vs30 values
norm = mcolors.Normalize(vmin=100, vmax=900)
# Generate colors based on the normalized values
colors = cmap(norm([tr.stats.vs30 for tr in traces]))
# Choose traces spread evenly over the trace list
# NOTE(review): the selection is by position in the stream, not by sorted distance - confirm
# that the traces are ordered by distance if an even spread over distance is wanted.
distances = [tr.stats.distance for tr in traces]  # in m
t_offset_guessed_1 = min(distances) / 6000.  # time 0 should be time of P-wave
t_offset_guessed_2 = 500 / traces[0].stats.sampling_rate  # minus the pre-event time
t_offset = t_offset_guessed_1 - t_offset_guessed_2
# Never request more traces than the stream contains
ixs = select_maximally_spaced_values(distances, n=min(ntraces, len(distances)))
for ix in ixs:
    tr = traces[ix]
    t = np.linspace(t_offset, t_offset + (tr.stats.npts - 1) * tr.stats.delta, tr.stats.npts)
    # Each waveform is scaled and drawn at its distance (km) on the y axis
    ax.plot(t, tr.data * scale + tr.stats.distance / 1000., color=colors[ix], alpha=0.8, linewidth=0.8)
# Add reference lines (drawn over the time axis of the last plotted trace)
h2, = ax.plot(t, t * 6., ":", color="k")
h3, = ax.plot(t, t * 3.3, "--", color="k")
ax.legend([h2, h3], [r"V$_P$=6 km/s", r"V$_S$=3.3 km/s"])
# Create a ScalarMappable object for the colorbar
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # We don't need data for the colorbar
# Add the colorbar
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label('V$_{S30}$ [m/s]')
ax.set_xlim(0, t.max())
ax.set_ylim(min(distances) / 1000. * 0.5, max(distances) / 1000. * 1.1)
ax.set_xlabel("Time after origin [s]")
ax.set_ylabel("Scaled waveform at distance [km]")
# Make sure the output directory exists before saving
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/section_test.png", dpi=300, bbox_inches='tight')
plt.show()