In [None]:
from pathlib import Path

import numpy as np
import torch

import matplotlib.pyplot as plt
import matplotlib.style as style

# Shared plotting style, loaded from a URL each time this cell runs.
# NOTE(review): this requires network access; vendoring the .mplstyle file next
# to the notebook would make the figures reproducible offline.
style.use(
    "https://raw.githubusercontent.com/dominik-strutz/dotfiles/main/mystyle.mplstyle"
)

from matplotlib import colormaps
from matplotlib.colors import ListedColormap


# Color settings
def create_modified_cmap(cmap_base, vmin=0.0, vmax=0.9, n_colors=256):
    """
    Create a colormap that covers only a sub-range of a base colormap.

    With the defaults the last 10% of the base colormap (values above 0.9)
    is excluded. Passing ``vmin > vmax`` yields the sub-range in reverse order.

    Parameters:
    -----------
    cmap_base : str
        Name of a colormap registered in ``matplotlib.colormaps``.
    vmin : float
        Start of the sampled range of the base colormap (0.0 to 1.0).
    vmax : float
        End of the sampled range of the base colormap (0.0 to 1.0).
    n_colors : int
        Number of discrete colours in the returned colormap. The base
        colormap is also resampled to this many entries before sampling.
        Defaults to 256.

    Returns:
    --------
    ListedColormap
        A new colormap with ``n_colors`` entries spaced evenly over
        ``[vmin, vmax]`` of the base colormap.

    Raises:
    -------
    KeyError
        If ``cmap_base`` is not a registered colormap name.
    ValueError
        If ``n_colors`` is smaller than 1.
    """
    if n_colors < 1:
        raise ValueError(f"n_colors must be >= 1, got {n_colors}")

    base_cmap = colormaps[cmap_base].resampled(n_colors)
    sampled_colors = base_cmap(np.linspace(vmin, vmax, n_colors))
    return ListedColormap(sampled_colors)


# Create custom colormaps: a truncated copy of each sequential base map
# (see create_modified_cmap) together with its reversed counterpart.
newcmp_blues, newcmp_reds, newcmp_greens, newcmp_purples = (
    create_modified_cmap(base_name)
    for base_name in ("Blues", "Reds", "Greens", "Purples")
)
(
    newcmp_blues_r,
    newcmp_reds_r,
    newcmp_greens_r,
    newcmp_purples_r,
) = (
    forward_cmap.reversed()
    for forward_cmap in (newcmp_blues, newcmp_reds, newcmp_greens, newcmp_purples)
)

In [None]:
from dased.helpers.srcloc import (
    MagnitudeRelation,
    ForwardHomogeneous,
    DataLikelihood,
)

In [None]:
from helpers import PhaseLookupCalculator

# Site location in degrees (lat, lon), passed to PhaseLookupCalculator, which
# presumably builds a site-specific velocity model from it (see
# plot_velocity_model below). Uncomment one alternative to switch site.
# lat, lon = 36.60, 25.65 # Aegean Sea
# lat, lon = 52.5, -0.4 # UK, Midlands

# Edinburgh, UK
# lat, lon = 55.95, -3.19 # Edinburgh, UK

# Munich, Germany
lat, lon = 48.14, 11.58  # Munich, Germany


# Lookup grids in metres (written as km * 1e3; plots below divide by 1e3):
#   distance_grid        0 .. 19.9 km, 100 m spacing
#   source_depth_grid    0.4 .. 5.0 km, 100 m spacing
#   receiver_depth_grid  a single depth, 0 m (np.arange's stop is exclusive)
distance_grid = np.arange(0.0, 20 * 1e3, 0.1 * 1e3)
source_depth_grid = np.arange(0.4 * 1e3, 5.1 * 1e3, 0.1 * 1e3)
# receiver_depth_grid = np.arange(0.4*1e3, 0.5*1e3, 0.1*1e3)
receiver_depth_grid = np.arange(0.0 * 1e3, 0.1 * 1e3, 0.1 * 1e3)


phase_lookup = PhaseLookupCalculator(
    lat,
    lon,
    distance_grid,
    source_depth_grid,
    receiver_depth_grid,
)

phase_lookup.plot_velocity_model()

# Build the lookup tables (the next cell reads "incidence_angle" and
# "arrival_time" from them). The dict keys ("p", "s") become the keys of
# lookup_ds; each value lists the phase names combined for that wave type
# (presumably in the ray tracer's phase notation — confirm).
lookup_ds = phase_lookup(
    {
        "p": ["p", "P", "P\\", "p\\"],
        "s": ["s", "S", "S\\", "s\\"],
    }
)

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8))

wavetype = "p"
lookup = lookup_ds[wavetype]
# Coordinates are stored in metres; plot in km.
distance_km = lookup["distance"] * 1e-3
source_depth_km = lookup["source_depth"] * 1e-3

# Index [0] selects the first receiver depth (presumably the leading table axis;
# receiver_depth_grid above holds a single depth, 0 m — TODO confirm).
# Filled contours every 10 deg. vmax=40 pins the top of the colour scale at
# 40 deg, so every band above 40 deg is drawn in the same (top) colour.
im = ax1.contourf(
    distance_km,
    source_depth_km,
    lookup["incidence_angle"][0].T,
    levels=np.arange(0, 91, 10),
    cmap="viridis",
    vmax=40,
)
fig.colorbar(im, ax=ax1, shrink=0.6, label="Incidence angle [deg]")
ax1.set_title(f"{wavetype.upper()}-wave incidence angle")

im = ax2.pcolormesh(
    distance_km,
    source_depth_km,
    lookup["arrival_time"][0].T,
    cmap="viridis",
    shading="nearest",
)
# NOTE(review): unit assumed to be seconds — confirm against PhaseLookupCalculator.
fig.colorbar(im, ax=ax2, shrink=0.6, label="Arrival time [s]")
ax2.set_title(f"{wavetype.upper()}-wave arrival time")

for ax in (ax1, ax2):
    # ax.set_aspect('equal')
    ax.set_xlabel("Distance [km]")
    ax.set_ylabel("Source depth [km]")

    # ax.set_ylim(0, 5)
    # ax.set_xlim(0, 20)

    ax.invert_yaxis()  # depth increases downwards

plt.show()

In [None]:
from dased.helpers.srcloc import ForwardLayeredLookup

reference_distance = 10e3  # in metres (= 10 km)
f_max = 10.0  # Hz

# Magnitude/distance relations for P and S (coefficients as given).
distance_relation_P = MagnitudeRelation(
    magnitude_factor=0.437,
    log_coeff=-1.269,
    reference_distance=reference_distance,
)
distance_relation_S = MagnitudeRelation(
    magnitude_factor=0.69,
    log_coeff=-1.588,
    reference_distance=reference_distance,
    #    reference_relation=distance_relation_P
)

# Layered-model forward functions read precomputed lookup tables for the site
# configured above. The file names are derived from (lat, lon), so switching
# site cannot silently load tables computed for another location.
# NOTE(review): assumes the tables follow the "<wave>_lookup_<lat>_<lon>.nc"
# naming (default float formatting) of the existing Munich files — confirm
# against how the files are generated.
lookup_site = f"{lat}_{lon}"

forward_function_P_layered = ForwardLayeredLookup(
    f"data/p_lookup_{lookup_site}.nc",
    wave_type="P",
    distance_relation=distance_relation_P,
    sensor_type="strain",
    # theta_sigma=10.0 * math.pi / 180,
)
forward_function_S_layered = ForwardLayeredLookup(
    f"data/s_lookup_{lookup_site}.nc",
    wave_type="S",
    distance_relation=distance_relation_S,
    sensor_type="strain",
    # theta_sigma=10.0 * math.pi / 180,
)

# Homogeneous-model forward functions (velocities presumably in m/s).
forward_function_P = ForwardHomogeneous(
    velocity=2500,
    wave_type="P",
    distance_relation=distance_relation_P,  # incidence_max=np.deg2rad(40)
)

forward_function_S = ForwardHomogeneous(
    velocity=1500,
    wave_type="S",
    distance_relation=distance_relation_S,  # incidence_max=np.deg2rad(40)
)

forward_function_SH = ForwardHomogeneous(
    velocity=1500,
    wave_type="SH",
    distance_relation=distance_relation_S,  # incidence_max=np.deg2rad(40)
)

forward_function_SV = ForwardHomogeneous(
    velocity=1500,
    wave_type="SV",
    distance_relation=distance_relation_S,  # incidence_max=np.deg2rad(40)
)

# Data likelihood built on the homogeneous P and S forward functions.
data_likelihood = DataLikelihood(
    forward_function=dict(
        P=forward_function_P,
        S=forward_function_S,
    ),
    std_corr=0.05,
    std_uncorr=0.05,
    cor_length=0.0,
    f_max=f_max,
    std_cutoff=10.0,
    K_sh=10.0,
)

In [None]:
# # ==== CONFIGURABLE PARAMETERS ====
# # Plotting grid parameters
# distances = torch.linspace(0.1, reference_distance, 200)  # km
# thetas = torch.linspace(0, 2 * np.pi, 360)  # radians
# # phi_waves = np.deg2rad([0, 30, 60, 90])  # Multiple phi values in radians
# phi_waves = np.deg2rad([0, 10, 30, 60, 90])  # Multiple phi values in radians

# # Plot limits
# snr_vmin, snr_vmax = 1, 5  # SNR range
# std_vmin, std_vmax = 0.0, 0.1  # Uncertainty range

# # ==== CALCULATIONS ====
# # Create grid for plotting
# distance_grid, theta_grid = torch.meshgrid(distances, thetas, indexing="ij")


# # ==== PLOTTING FUNCTION ====
# def plot_wave_characteristics(wave_type, forward_function):
#     """
#     Plot SNR and uncertainty for a given wave type.

#     Parameters:
#     -----------
#     wave_type : str
#         Type of wave ('P' or 'S')
#     forward_function : ForwardHomogeneous
#         Forward function to calculate SNR
#     Notes
#     -----
#     Figure size (2.5 x 5 in per column) and title font size (14) are fixed
#     inside the function; ``data_likelihood`` is taken from the enclosing scope.
#     """
#     n_cols = len(phi_waves)
#     width, height = (2.5, 5)
#     fig = plt.figure(figsize=(width * n_cols, height))

#     # Create grid for subplots
#     grid = fig.add_gridspec(2, n_cols + 1, width_ratios=[1] * n_cols + [0.05])

#     # Create axes for colorbars
#     cbar_snr_ax = fig.add_subplot(grid[0, -1])
#     cbar_std_ax = fig.add_subplot(grid[1, -1])

#     axes = [
#         [fig.add_subplot(grid[row, col], projection="polar") for col in range(n_cols)]
#         for row in range(2)
#     ]

#     for col, phi_wave in enumerate(phi_waves):
#         # Calculate SNR and STD for this phi angle
#         snr = forward_function.snr_from_angles(
#             distance_grid,
#             phi1=theta_grid,
#             phi2=phi_wave,  # Using phi1 for wave direction
#         )
#         std = data_likelihood.snr2std(snr)

#         # SNR plot (top row)
#         pcm_snr = axes[0][col].pcolormesh(
#             theta_grid,
#             distance_grid,
#             snr,
#             vmin=snr_vmin,
#             vmax=snr_vmax,
#             cmap=newcmp_blues,
#             shading="auto",
#         )
#         axes[0][col].set_title(
#             f"φ={np.rad2deg(phi_wave):.0f}°",
#             fontsize=14,
#         )

#         # STD plot (bottom row)
#         pcm_std = axes[1][col].pcolormesh(
#             theta_grid,
#             distance_grid,
#             std,
#             vmin=std_vmin,
#             vmax=std_vmax,
#             cmap=newcmp_blues_r,
#             shading="auto",
#         )

#         # Common settings for both plots
#         for row in [0, 1]:
#             ax = axes[row][col]
#             ax.set_ylim(0, reference_distance)
#             ax.set_xticks([])
#             ax.set_yticks([])
#             ax.grid(False)

#     # Large Rotated Text on the left indicating the wave type
#     ax_text = fig.add_subplot(grid[:, :])
#     ax_text.set_xticks([])
#     ax_text.set_yticks([])
#     ax_text.grid(False)
#     for sp in ax_text.spines.values():
#         sp.set_visible(False)

#     ax_text.set_ylabel(f"{wave_type} wave", fontsize=16, weight="bold", rotation=90)

#     # Add colorbars
#     cbar_snr_ax.set_box_aspect(15)
#     cbar_std_ax.set_box_aspect(15)

#     plt.colorbar(pcm_snr, cax=cbar_snr_ax, label="SNR", extend="both")
#     plt.colorbar(pcm_std, cax=cbar_std_ax, label="std dev [s]", extend="both")

#     plt.show()


# plot_wave_characteristics("P", forward_function_P)
# plot_wave_characteristics("S", forward_function_S)
# plot_wave_characteristics("SH", forward_function_SH)
# plot_wave_characteristics("SV", forward_function_SV)

In [None]:
import geopandas as gpd
from shapely.geometry import Point

# Changeable Parameters
# ----------------------
# n_radial = 200
# n_azimuth = 360

n_radial = 100
n_azimuth = 180

max_radius = reference_distance  # in meters
# source_depths = [5e3, 10e3, 15e3]
source_depths = [-500.0, -1000.0, -2500.0]  # in meters
N_depth = len(source_depths)

# Plot configuration
# plot_components = ['traveltime', 'sensitivity', 'snr']  # Choose which to plot
plot_components = ["sensitivity", "snr"]  # Choose which to plot

figsize_per_plot = (2.0, 2.0)  # Base size for each subplot

# Colorbar ticks per component (None = matplotlib's default ticks). They are
# keyed by component rather than by row index, so reordering or extending
# `plot_components` keeps every colorbar's ticks correct.
colorbar_ticks = {
    "traveltime": None,
    "sensitivity": [0, 0.2, 0.4, 0.6, 0.8, 1.0],
    "snr": [0, 1, 2, 3, 4, 5],
}
unknown_components = set(plot_components) - set(colorbar_ticks)
if unknown_components:
    raise ValueError(f"Unknown plot component(s): {sorted(unknown_components)}")

# Create polar grid
# -----------------
radii = torch.linspace(0, max_radius, n_radial)
azimuths = torch.linspace(0.0, 2 * np.pi, n_azimuth)
r_grid, a_grid = torch.meshgrid(radii, azimuths, indexing="ij")
x = r_grid * torch.cos(a_grid)
y = r_grid * torch.sin(a_grid)

# Create GeoDataFrame with station locations
# -----------------------------------------
# One receiver per (radius, azimuth) node in row-major order (radius outer,
# azimuth inner), so per-receiver results reshape back to (n_radial, n_azimuth).
# All receivers have elevation 0 and u = (1, 0, 0) (presumably the sensor
# orientation — confirm).
points = [Point(px, py) for px, py in zip(x.flatten().tolist(), y.flatten().tolist())]
design = gpd.GeoDataFrame(
    {
        "geometry": points,
        "elevation": [0.0] * len(points),
        "u_x": [1.0] * len(points),
        "u_y": [0.0] * len(points),
        "u_z": [0.0] * len(points),
    }
)


def _draw_polar_component(
    ax, component, wave_type, theta, radius_km, times_grid, sensitivity_grid, snr_grid
):
    """
    Draw one plot component onto a polar axis.

    Returns (mappable, colorbar label, row label). Raises ValueError for an
    unknown component name.
    """
    contour_style = dict(colors="k", linewidths=0.5, linestyles="--", alpha=0.5)
    if component == "traveltime":
        pcm = ax.pcolormesh(
            theta, radius_km, times_grid.T, cmap="Blues", shading="auto"
        )
        return pcm, "Time (s)", f"{wave_type} Travel Time"
    if component == "sensitivity":
        pcm = ax.pcolormesh(
            theta,
            radius_km,
            sensitivity_grid.T,
            cmap=newcmp_blues,
            shading="auto",
            vmin=0,
            vmax=1,
        )
        ax.contour(
            theta,
            radius_km,
            sensitivity_grid.T,
            levels=[0.2, 0.4, 0.6, 0.8, 1.0],
            **contour_style,
        )
        return pcm, "Sensitivity", f"{wave_type} Sensitivity"
    if component == "snr":
        # Clamp so the contour levels and colour scale cover a fixed 0..5 range.
        snr_plot = snr_grid.clamp(0, 5)
        pcm = ax.pcolormesh(
            theta,
            radius_km,
            snr_plot.T,
            cmap=newcmp_purples,
            shading="auto",
            vmin=1,
            vmax=5,
        )
        ax.contour(
            theta,
            radius_km,
            snr_plot.T,
            levels=[0, 1, 2, 3, 4, 5],
            **contour_style,
        )
        return pcm, "SNR", f"{wave_type} SNR"
    raise ValueError(f"Unknown plot component: {component!r}")


# Polar plotting coordinates: azimuth (rad) and radius in km.
theta = a_grid.T
radius_km = r_grid.T / 1000

# Plot results for P and S waves
# ------------------------------
for wave_type, forward_function in [
    ("P (homogeneous)", forward_function_P),
    ("S (homogeneous)", forward_function_S),
    ("P (layered)", forward_function_P_layered),
    ("S (layered)", forward_function_S_layered),
]:
    # Calculate figure size based on components and source depths
    n_rows = len(plot_components)
    n_cols = len(source_depths)
    figsize = (figsize_per_plot[0] * (n_cols + 0.2), figsize_per_plot[1] * n_rows)

    fig = plt.figure(figsize=figsize, dpi=120)
    gs = fig.add_gridspec(n_rows, n_cols + 1, width_ratios=[1] * n_cols + [0.05])

    axes = []
    for row in range(n_rows):
        row_axes = []
        for col in range(n_cols):
            ax = fig.add_subplot(gs[row, col], projection="polar")
            ax.set_theta_zero_location("E")
            ax.set_theta_direction(-1)
            ax.grid(False)
            ax.set_xticks([])
            ax.set_yticks([])
            row_axes.append(ax)
        axes.append(row_axes)

    for i, depth in enumerate(source_depths):
        # Source at the origin at the given depth
        m = torch.tensor([[0, 0, depth]])
        times, snr = forward_function(m, design)
        sensitivity = forward_function.directional_sensitivity(m, design)

        # Reshape per-receiver results back onto the polar grid
        times_grid = times.reshape(n_radial, n_azimuth)
        snr_grid = snr.reshape(n_radial, n_azimuth)
        sensitivity_grid = sensitivity.reshape(n_radial, n_azimuth)

        for j, component in enumerate(plot_components):
            ax = axes[j][i]
            pcm, label, row_label = _draw_polar_component(
                ax,
                component,
                wave_type,
                theta,
                radius_km,
                times_grid,
                sensitivity_grid,
                snr_grid,
            )

            if j == 0:
                ax.set_title(f"Depth {-depth/1000:.1f} km", fontsize=8)

            # Add max radius text in the northeast quadrant, just outside the rim
            ax.text(
                -np.pi / 4,
                max_radius / 1000 * 1.05,
                rf"$R_\text{{max}}: {max_radius/1000:.0f} \text{{km}}$",
                ha="left",
                va="bottom",
                fontsize=8,
            )

            # Add row labels to the first column
            if i == 0:
                ax.set_ylabel(row_label, fontsize=8, rotation=90, labelpad=15)

            # Add one colorbar per row, next to the last column
            if i == n_cols - 1:
                cax = fig.add_subplot(gs[j, -1])
                cax.set_box_aspect(12)
                cbar = fig.colorbar(pcm, cax=cax, label=label, extend="both")
                ticks = colorbar_ticks[component]
                if ticks is not None:
                    cbar.ax.set_yticks(ticks)

    plt.show()

In [None]:
wave_types = [
    ("P", forward_function_P, forward_function_P_layered),
    ("S", forward_function_S, forward_function_S_layered),
]
column_titles = ["Homogeneous", "Horizontally Layered"]

# Set up figure dimensions with extra space between models
n_models = 2
n_empty = 1  # number of empty columns between models
n_total_cols = n_models * N_depth + n_empty

snr_vmin, snr_vmax = 1, 5  # SNR range

# First grid column of each model block: homogeneous at 0, layered after the
# spacer column (index N_depth).
model_col_offsets = (0, N_depth + 1)
contour_kwargs = dict(colors="k", linewidths=0.5, linestyles="--", alpha=0.5)
# Polar plotting coordinates: azimuth (rad) and radius in km.
theta = a_grid.T
radius_km = r_grid.T / 1000

# savefig does not create missing directories.
figures_dir = Path("figures")
figures_dir.mkdir(parents=True, exist_ok=True)

fig = plt.figure(figsize=(3 * N_depth + 6, 3 * 3 + 1))  # Add space for titles
gs = fig.add_gridspec(
    5,
    n_total_cols + 1,
    # N_depth panels, 0.2 spacer, N_depth panels, 0.1 colorbar column
    width_ratios=[1] * N_depth + [0.2] + [1] * N_depth + [0.1],
    height_ratios=[0.1, 1, 1, 1, 1],
    wspace=0.1,
    hspace=0.2,
)

# Axes rows: P sensitivity, P SNR, S sensitivity, S SNR; spacer column left empty.
axes = np.empty((4, n_total_cols), dtype=object)
for i in range(4):
    for j in range(n_total_cols):
        if j == N_depth:  # skip the empty column
            continue
        ax = fig.add_subplot(gs[i + 1, j], projection="polar")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)
        axes[i, j] = ax

# Model titles spanning each block of depth columns
for col_offset, column_title in zip(model_col_offsets, column_titles):
    title_ax = fig.add_subplot(gs[0, col_offset : col_offset + N_depth])
    title_ax.text(
        0.5,
        1.2,
        column_title,
        ha="center",
        va="bottom",
        fontsize=18,
        fontweight="bold",
    )
    title_ax.axis("off")

# Depth headers above every panel column of both models
for depth_idx, depth in enumerate(source_depths):
    for col_offset in model_col_offsets:
        header_ax = fig.add_subplot(gs[0, col_offset + depth_idx])
        header_ax.text(
            0.5, 0.4, f"z={-depth/1000:.1f} km", ha="center", va="top", fontsize=14
        )
        header_ax.axis("off")

# Add R_max at the center of the figure
fig.text(
    0.5,
    0.835,
    rf"$R_\text{{max}}: {max_radius/1000:.0f} \text{{km}}$",
    ha="center",
    va="center",
    fontsize=15,
    fontweight="bold",
    color="k",
    zorder=100,
)

# Create plots for each wave type, sensitivity/SNR, and model type
for wave_idx, (wave_label, f_hom, f_layered) in enumerate(wave_types):
    sens_row = 2 * wave_idx
    snr_row = 2 * wave_idx + 1
    for depth_idx, depth in enumerate(source_depths):
        m = torch.tensor([[0, 0, depth]])
        for col_offset, forward_function in zip(model_col_offsets, (f_hom, f_layered)):
            col = col_offset + depth_idx
            times, snr = forward_function(m, design)
            sensitivity = forward_function.directional_sensitivity(m, design)
            snr_grid = snr.reshape(n_radial, n_azimuth)
            sensitivity_grid = sensitivity.reshape(n_radial, n_azimuth)
            snr_clamped = snr_grid.T.clamp(snr_vmin, snr_vmax)

            pcm_snr = axes[snr_row, col].pcolormesh(
                theta,
                radius_km,
                snr_clamped,
                cmap=newcmp_greens,
                shading="auto",
                vmin=snr_vmin,
                vmax=snr_vmax,
                rasterized=True,
            )
            axes[snr_row, col].contour(
                theta, radius_km, snr_clamped, levels=[1, 2, 3, 4, 5], **contour_kwargs
            )
            pcm_sens = axes[sens_row, col].pcolormesh(
                theta,
                radius_km,
                sensitivity_grid.T,
                cmap=newcmp_blues,
                shading="auto",
                vmin=0,
                vmax=1,
                rasterized=True,
            )
            axes[sens_row, col].contour(
                theta,
                radius_km,
                sensitivity_grid.T,
                levels=[0.2, 0.4, 0.6, 0.8],
                **contour_kwargs,
            )
            if col == 0:
                axes[snr_row, 0].set_ylabel(
                    f"{wave_label} SNR", fontsize=16, labelpad=30
                )
                axes[sens_row, 0].set_ylabel(
                    f"{wave_label} Sensitivity", fontsize=16, labelpad=30
                )

# One colorbar per row. All panels of a kind share cmap and limits, so the last
# mappable of each kind represents the whole row. (shrink/pad are ignored by
# matplotlib when cax is given, so they are not passed.)
for i in range(4):
    cbar_ax = fig.add_subplot(gs[i + 1, -1])
    cbar_ax.set_box_aspect(8)
    if i % 2 == 0:  # Sensitivity rows
        cbar = fig.colorbar(pcm_sens, cax=cbar_ax, extend="both")
        cbar.ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1])
    else:  # SNR rows
        cbar = fig.colorbar(pcm_snr, cax=cbar_ax, extend="both")
        cbar.ax.set_yticks([1, 2, 3, 4, 5])

fig.savefig(figures_dir / "sensitivity_snr_bodywaves.png", dpi=300, bbox_inches="tight")
fig.savefig(figures_dir / "sensitivity_snr_bodywaves.pdf", dpi=300, bbox_inches="tight")

plt.show()

In [None]:
wave_types = [
    ("P", forward_function_P, forward_function_P_layered),
    ("S", forward_function_S, forward_function_S_layered),
]
column_titles = ["Homogeneous", "Horizontally Layered"]

# Set up figure dimensions with extra space between models
n_models = 2
n_empty = 1  # number of empty columns between models
n_total_cols = n_models * N_depth + n_empty

snr_vmin, snr_vmax = 1, 5  # SNR range

# First grid column of each model block: homogeneous at 0, layered after the
# spacer column (index N_depth).
model_col_offsets = (0, N_depth + 1)
contour_kwargs = dict(colors="k", linewidths=0.5, linestyles="--", alpha=0.5)
# Polar plotting coordinates: azimuth (rad) and radius in km.
theta = a_grid.T
radius_km = r_grid.T / 1000

# savefig does not create missing directories.
figures_dir = Path("figures")
figures_dir.mkdir(parents=True, exist_ok=True)

# Create two separate figures - one for P waves and one for S waves
for wave_label, f_hom, f_layered in wave_types:
    fig = plt.figure(figsize=(3 * N_depth + 6, 3 * 2))  # two panel rows per figure
    gs = fig.add_gridspec(
        3,
        n_total_cols + 1,
        # N_depth panels, 0.2 spacer, N_depth panels, 0.1 colorbar column
        width_ratios=[1] * N_depth + [0.2] + [1] * N_depth + [0.1],
        height_ratios=[0.1, 1, 1],
        wspace=0.1,
        hspace=0.1,
    )

    # Row 0: sensitivity, row 1: SNR; spacer column left empty.
    sens_row, snr_row = 0, 1
    axes = np.empty((2, n_total_cols), dtype=object)
    for i in range(2):
        for j in range(n_total_cols):
            if j == N_depth:  # skip the empty column
                continue
            ax = fig.add_subplot(gs[i + 1, j], projection="polar")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.grid(False)
            axes[i, j] = ax

    # Model titles spanning each block of depth columns
    for col_offset, column_title in zip(model_col_offsets, column_titles):
        title_ax = fig.add_subplot(gs[0, col_offset : col_offset + N_depth])
        title_ax.text(
            0.5,
            1.2,
            column_title,
            ha="center",
            va="bottom",
            fontsize=18,
            fontweight="bold",
        )
        title_ax.axis("off")

    # Depth headers above every panel column of both models
    for depth_idx, depth in enumerate(source_depths):
        for col_offset in model_col_offsets:
            header_ax = fig.add_subplot(gs[0, col_offset + depth_idx])
            header_ax.text(
                0.5, 0.4, f"z={-depth/1000:.1f} km", ha="center", va="top", fontsize=14
            )
            header_ax.axis("off")

    # Add R_max at the center of the figure
    fig.text(
        0.5,
        0.79,
        rf"$R_\text{{max}}: {max_radius/1000:.0f} \text{{km}}$",
        ha="center",
        va="center",
        fontsize=15,
        fontweight="bold",
        color="k",
        zorder=100,
    )

    for depth_idx, depth in enumerate(source_depths):
        m = torch.tensor([[0, 0, depth]])
        for col_offset, forward_function in zip(model_col_offsets, (f_hom, f_layered)):
            col = col_offset + depth_idx
            times, snr = forward_function(m, design)
            sensitivity = forward_function.directional_sensitivity(m, design)
            snr_grid = snr.reshape(n_radial, n_azimuth)
            sensitivity_grid = sensitivity.reshape(n_radial, n_azimuth)
            snr_clamped = snr_grid.T.clamp(snr_vmin, snr_vmax)

            pcm_snr = axes[snr_row, col].pcolormesh(
                theta,
                radius_km,
                snr_clamped,
                cmap=newcmp_greens,
                shading="auto",
                vmin=snr_vmin,
                vmax=snr_vmax,
                rasterized=True,
            )
            axes[snr_row, col].contour(
                theta, radius_km, snr_clamped, levels=[1, 2, 3, 4, 5], **contour_kwargs
            )
            pcm_sens = axes[sens_row, col].pcolormesh(
                theta,
                radius_km,
                sensitivity_grid.T,
                cmap=newcmp_blues,
                shading="auto",
                vmin=0,
                vmax=1,
                rasterized=True,
            )
            axes[sens_row, col].contour(
                theta,
                radius_km,
                sensitivity_grid.T,
                levels=[0.2, 0.4, 0.6, 0.8],
                **contour_kwargs,
            )
            if col == 0:
                axes[snr_row, 0].set_ylabel(
                    f"{wave_label} SNR", fontsize=16, labelpad=30, fontweight="bold"
                )
                axes[sens_row, 0].set_ylabel(
                    f"{wave_label} Sensitivity",
                    fontsize=16,
                    labelpad=30,
                    fontweight="bold",
                )

    # One colorbar per row; panels of a kind share cmap and limits, so the last
    # mappable represents the row. (shrink/pad are ignored when cax is given.)
    row_colorbars = (
        (pcm_sens, [0, 0.2, 0.4, 0.6, 0.8, 1]),  # sensitivity row
        (pcm_snr, [1, 2, 3, 4, 5]),  # SNR row
    )
    for i, (pcm, ticks) in enumerate(row_colorbars):
        cbar_ax = fig.add_subplot(gs[i + 1, -1])
        cbar_ax.set_box_aspect(8)
        cbar = fig.colorbar(pcm, cax=cbar_ax, extend="both")
        cbar.ax.set_yticks(ticks)

    # Save figures with wave label in filename
    fig.savefig(
        figures_dir / f"sensitivity_snr_{wave_label}waves.png",
        dpi=300,
        bbox_inches="tight",
    )
    fig.savefig(
        figures_dir / f"sensitivity_snr_{wave_label}waves.pdf",
        dpi=300,
        bbox_inches="tight",
    )

plt.show()


In [None]:
# Velocity profile of phase_lookup's model: P and S velocity against depth.
# Each layer contributes a point at its top and at its bottom, and the velocity
# is evaluated at both depths. Previously only the top value was used for both
# points, which drew depth-varying layers as constant.
depths = []
vp = []
vs = []

fig, ax = plt.subplots(figsize=(3, 4))


for layer in phase_lookup.model.layers():
    top = layer.ztop
    bottom = layer.zbot
    depths.extend([top, bottom])

    material_top = layer.material(top)
    material_bottom = layer.material(bottom)

    vp.extend([material_top.vp, material_bottom.vp])
    vs.extend([material_top.vs, material_bottom.vs])

depths = np.array(depths)
vp = np.array(vp)
vs = np.array(vs)

# Convert m -> km and m/s -> km/s for display.
ax.plot(vp / 1e3, depths / 1e3, label="Vp (P-wave)", color="b")
ax.plot(vs / 1e3, depths / 1e3, label="Vs (S-wave)", color="r")

# Show down to twice the deepest source/receiver depth of the lookup grids.
max_grid_depth = max(
    np.max(phase_lookup.source_depth_grid), np.max(phase_lookup.receiver_depth_grid)
)
ax.set_ylim(0, max_grid_depth / 1e3 * 2)

ax.invert_yaxis()  # depth increases downwards
ax.set_xlabel("Velocity (km/s)")
ax.set_ylabel("Depth (km)")
ax.legend()
ax.set_title("Velocity Model")
plt.show()

In [None]:
wave_types = [
    ("P", forward_function_P, forward_function_P_layered),
    ("S", forward_function_S, forward_function_S_layered),
]
column_titles = ["Homogeneous", "Horizontally Layered"]

# Set up figure dimensions with extra space between models
n_models = 2
n_empty = 1  # number of empty columns between models
n_total_cols = n_models * N_depth + n_empty

snr_vmin, snr_vmax = 1, 5  # SNR range

# Create two separate figures - one for P waves and one for S waves
for wave_idx, (wave_label, f_hom, f_layered) in enumerate(wave_types):
    fig = plt.figure(
        figsize=(3 * N_depth + 10, 3 * 2)
    )  # slightly increased width to fit velocity model
    # Add one extra column for the velocity model (placed before the colorbar)
    gs = fig.add_gridspec(
        3,
        n_total_cols + 3,
        width_ratios=[1] * N_depth + [0.2] + [1] * N_depth + [0.5] + [0.5] + [1.0],
        height_ratios=[0.1, 1, 1],
        wspace=0.1,
        hspace=0.1,
    )

    # Create arrays for the axes - 2 rows per figure (only for the polar plots)
    axes = np.empty((2, n_total_cols), dtype=object)

    # Create all subplots, skipping the empty column
    for i in range(2):
        for j in range(n_total_cols):
            if j == N_depth:  # skip the empty column between models
                continue
            axes[i, j] = fig.add_subplot(gs[i + 1, j], projection="polar")
            axes[i, j].set_xticks([])
            axes[i, j].set_yticks([])
            axes[i, j].grid(False)

    # Add model titles at the top
    for model_idx in range(2):
        if model_idx == 0:
            title_col_start = 0
            title_col_end = N_depth - 1
        else:
            title_col_start = N_depth + 1
            title_col_end = N_depth * 2
        title_ax = fig.add_subplot(gs[0, title_col_start : title_col_end + 1])
        title_ax.text(
            0.5,
            1.2,
            column_titles[model_idx],
            ha="center",
            va="bottom",
            fontsize=18,
            fontweight="bold",
        )
        title_ax.axis("off")

    # Add depth headers
    for depth_idx, depth in enumerate(source_depths):
        # Homogeneous
        header_ax = fig.add_subplot(gs[0, depth_idx])
        header_ax.text(
            0.5, 0.4, f"z={-depth/1000:.1f} km", ha="center", va="top", fontsize=14
        )
        header_ax.axis("off")
        # Layered
        header_ax = fig.add_subplot(gs[0, N_depth + 1 + depth_idx])
        header_ax.text(
            0.5, 0.4, f"z={-depth/1000:.1f} km", ha="center", va="top", fontsize=14
        )
        header_ax.axis("off")

    # Add R_max at the center of the figure
    fig.text(
        0.42,
        0.79,
        rf"$R_\text{{max}}: {max_radius/1000:.0f} \text{{km}}$",
        ha="center",
        va="center",
        fontsize=15,
        fontweight="bold",
        color="k",
        zorder=100,
    )

    # For this figure, we only process the current wave type
    snr_row = 1  # SNR is bottom row in each figure
    sens_row = 0  # Sensitivity is top row in each figure

    for depth_idx, depth in enumerate(source_depths):
        # Homogeneous
        col = depth_idx
        m = torch.tensor([[0, 0, depth]])
        times, snr = f_hom(m, design)
        sensitivity = f_hom.directional_sensitivity(m, design)
        snr_grid = snr.reshape(n_radial, n_azimuth)
        sensitivity_grid = sensitivity.reshape(n_radial, n_azimuth)
        pcm_snr = axes[snr_row, col].pcolormesh(
            a_grid.T,
            r_grid.T / 1000,
            snr_grid.T.clamp(snr_vmin, snr_vmax),
            cmap=newcmp_greens,
            shading="auto",
            vmin=snr_vmin,
            vmax=snr_vmax,
            rasterized=True,
        )
        axes[snr_row, col].contour(
            a_grid.T,
            r_grid.T / 1000,
            snr_grid.T.clamp(snr_vmin, snr_vmax),
            levels=[1, 2, 3, 4, 5],
            colors="k",
            linewidths=0.5,
            linestyles="--",
            alpha=0.5,
        )
        pcm_sens = axes[sens_row, col].pcolormesh(
            a_grid.T,
            r_grid.T / 1000,
            sensitivity_grid.T,
            cmap=newcmp_blues,
            shading="auto",
            vmin=0,
            vmax=1,
            rasterized=True,
        )
        axes[sens_row, col].contour(
            a_grid.T,
            r_grid.T / 1000,
            sensitivity_grid.T,
            levels=[0.2, 0.4, 0.6, 0.8],
            colors="k",
            linewidths=0.5,
            linestyles="--",
            alpha=0.5,
        )
        if col == 0:
            axes[snr_row, 0].set_ylabel(
                f"{wave_label} SNR", fontsize=16, labelpad=30, fontweight="bold"
            )
            axes[sens_row, 0].set_ylabel(
                f"{wave_label} Sensitivity", fontsize=16, labelpad=30, fontweight="bold"
            )
        # Layered
        col = N_depth + 1 + depth_idx
        m = torch.tensor([[0, 0, depth]])
        times, snr = f_layered(m, design)
        sensitivity = f_layered.directional_sensitivity(m, design)
        snr_grid = snr.reshape(n_radial, n_azimuth)
        sensitivity_grid = sensitivity.reshape(n_radial, n_azimuth)
        pcm_snr = axes[snr_row, col].pcolormesh(
            a_grid.T,
            r_grid.T / 1000,
            snr_grid.T.clamp(snr_vmin, snr_vmax),
            cmap=newcmp_greens,
            shading="auto",
            vmin=snr_vmin,
            vmax=snr_vmax,
            rasterized=True,
        )
        axes[snr_row, col].contour(
            a_grid.T,
            r_grid.T / 1000,
            snr_grid.T.clamp(snr_vmin, snr_vmax),
            levels=[1, 2, 3, 4, 5],
            colors="k",
            linewidths=0.5,
            linestyles="--",
            alpha=0.5,
        )
        pcm_sens = axes[sens_row, col].pcolormesh(
            a_grid.T,
            r_grid.T / 1000,
            sensitivity_grid.T,
            cmap=newcmp_blues,
            shading="auto",
            vmin=0,
            vmax=1,
            rasterized=True,
        )
        axes[sens_row, col].contour(
            a_grid.T,
            r_grid.T / 1000,
            sensitivity_grid.T,
            levels=[0.2, 0.4, 0.6, 0.8],
            colors="k",
            linewidths=0.5,
            linestyles="--",
            alpha=0.5,
        )

    # Add a velocity-model subplot to the right (spanning both data rows)
    vel_ax = fig.add_subplot(gs[1:, -1])
    # Plot step Vp and Vs (convert to km/s and depth to km)
    vel_ax.step(
        vp / 1000.0,
        depths / 1000.0,
        where="post",
        label="Vp",
        color="tab:blue",
        linewidth=2,
    )
    vel_ax.step(
        vs / 1000.0,
        depths / 1000.0,
        where="post",
        label="Vs",
        color="tab:red",
        linewidth=2,
    )
    vel_ax.set_ylim(0, 5_000 / 1_000.0)
    vel_ax.set_xlim(0, 8)

    vel_ax.invert_yaxis()
    vel_ax.set_xlabel("Velocity (km/s)", fontsize=12)
    vel_ax.set_ylabel("Depth (km)", fontsize=12)
    vel_ax.set_title(
        "Horizontally Layered\nVelocity Model",
        fontsize=12,
    )
    vel_ax.legend(frameon=False)
    # vel_ax.grid(True, linestyle=':', linewidth=0.5)
    # Limit depth range to full model extent

    # Add colorbars
    for i in range(2):
        cbar_ax = fig.add_subplot(gs[i + 1, -3])
        cbar_ax.set_box_aspect(8)
        if i == 0:  # Sensitivity row
            cbar = plt.colorbar(
                pcm_sens, cax=cbar_ax, shrink=0.4, extend="both", pad=2.5
            )
            cbar.ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1])
        else:  # SNR row
            cbar = plt.colorbar(
                pcm_snr, cax=cbar_ax, shrink=0.4, extend="both", pad=2.5
            )
            cbar.ax.set_yticks([1, 2, 3, 4, 5])

    # Save figures with wave label in filename
    fig.savefig(
        f"figures/sensitivity_snr_{wave_label}waves_plus_velmodel.png",
        dpi=300,
        bbox_inches="tight",
    )
    fig.savefig(
        f"figures/sensitivity_snr_{wave_label}waves_plus_velmodel.pdf",
        dpi=300,
        bbox_inches="tight",
    )

plt.show()


In [None]:
from dased.criteria import RaySensitivity
import numpy as np
import torch
import matplotlib.pyplot as plt

# ==== Ray-sensitivity criteria for Rayleigh and Love waves ====
# Number of grid points per axis handed to RaySensitivity.
n_points = 100
# NOTE(review): this was previously labelled "km", but every other distance in
# this notebook is in metres (e.g. `distance_grid`, and `/ 1000` conversions for
# km axis labels). It is therefore presumably 10 km expressed in metres.
# TODO confirm against the units RaySensitivity expects.
reference_distance_tom_inter = 10_000.0
x_range = (-reference_distance_tom_inter, reference_distance_tom_inter)
y_range = (-reference_distance_tom_inter, reference_distance_tom_inter)

# Build one criterion per surface-wave type, with identical settings apart from
# `data_type`. The Rayleigh criterion is constructed first, as before.
rayleigh_criterion, love_criterion = (
    RaySensitivity(
        n_points=n_points,
        reference_distance=reference_distance_tom_inter,
        x_range=x_range,
        y_range=y_range,
        data_type=data_type,
    )
    for data_type in ("rayleigh", "love")
)

# Regular (theta1, theta2) grid over [-pi, pi] x [-pi, pi] for the SNR-scaling maps.
n_theta = 360
theta1_main = torch.linspace(-np.pi, np.pi, n_theta)
theta2_main = torch.linspace(-np.pi, np.pi, n_theta)
theta1_grid, theta2_grid = torch.meshgrid(theta1_main, theta2_main, indexing="ij")

# Azimuth passed to `_snr_scaling_pair` (radians). With phi = 0 the plotted
# "theta - phi" axes coincide with the raw theta values.
phi_main = 0
# NOTE(review): `distance_main` is NOT passed to `_snr_scaling_pair` below.
# It is kept only in case later cells reference it.
distance_main = 1.0

# `_snr_scaling_pair` is fed numpy arrays, so convert the torch grids once.
theta1_grid_np = theta1_grid.numpy()
theta2_grid_np = theta2_grid.numpy()
scaling_grid_rayleigh = rayleigh_criterion._snr_scaling_pair(
    phi_main, theta1_grid_np, theta2_grid_np
)
scaling_grid_love = love_criterion._snr_scaling_pair(
    phi_main, theta1_grid_np, theta2_grid_np
)

# ==== Geometry for the schematic diagram subplot ====
# Two points, each with a direction angle given in degrees, measured
# anticlockwise from East (+x).
point1_diag = np.array([1, 0])
point2_diag = np.array([0, 2])
theta1_diag_deg = 45
theta2_diag_deg = 135

# Same angles in radians.
theta1_diag_rad, theta2_diag_rad = np.deg2rad([theta1_diag_deg, theta2_diag_deg])

# Unit direction vectors (cos, sin) for each point.
v1_diag, v2_diag = (
    np.array([np.cos(angle_rad), np.sin(angle_rad)])
    for angle_rad in (theta1_diag_rad, theta2_diag_rad)
)


def draw_semi_circle_diag(ax, center, angle_deg, radius=0.15, color="k", lw=1.5):
    """
    Draw an angle marker around `center` on `ax`.

    The marker is an arc of the given `radius` sweeping from East (0 deg, the
    +x direction) to `angle_deg` (anticlockwise for positive angles). It also
    includes a straight East-pointing reference tick of length 1.3 * radius
    starting at `center`.

    Parameters
    ----------
    ax : object with a matplotlib-style ``plot`` method (normally an Axes).
    center : indexable (x, y) position of the arc centre.
    angle_deg : float
        End angle of the arc in degrees.
    radius : float
        Arc radius in data units.
    color, lw :
        Line colour and line width for both lines.

    Both lines are drawn at zorder -1 with clipping disabled.
    """
    cx, cy = center[0], center[1]
    line_kw = dict(color=color, lw=lw, zorder=-1, clip_on=False)

    # 100-sample arc from 0 to angle_deg.
    sweep = np.linspace(0, np.deg2rad(angle_deg), 100)
    ax.plot(cx + radius * np.cos(sweep), cy + radius * np.sin(sweep), **line_kw)

    # East reference tick, slightly longer than the arc radius.
    ax.plot([cx, cx + radius * 1.3], [cy, cy], **line_kw)


# Midpoint of the segment joining the two diagram points, and the direction
# (radians, anticlockwise from East) of the vector from point 1 to point 2.
midpoint_diag = (point1_diag + point2_diag) / 2
delta_diag = point2_diag - point1_diag
angle_between_diag_rad = np.arctan2(delta_diag[1], delta_diag[0])

# ==== Create Figure and Subplots ====
# Three panels side by side on a 14 x 5 inch figure:
#   column 0 - the schematic diagram (0.4 of a map column's width),
#   column 1 - Rayleigh-wave SNR-scaling map,
#   column 2 - Love-wave SNR-scaling map.
fig = plt.figure(figsize=(14, 5))
gs = fig.add_gridspec(1, 3, width_ratios=[0.4, 1, 1], wspace=0.4)

# Subplots are created left to right.
ax_diagram, axes_rayleigh, axes_love = (fig.add_subplot(gs[0, col]) for col in range(3))

# ==== Plot Diagram (on ax_diagram) ====
# A short headless quiver segment, centred on each point, shows its direction.
for _pt, _vec in ((point1_diag, v1_diag), (point2_diag, v2_diag)):
    ax_diagram.quiver(
        *_pt,
        *_vec,
        angles="xy",
        scale_units="xy",
        pivot="middle",
        scale=2,
        width=0.02,
        color="k",
        headlength=0,
        headaxislength=0,
        headwidth=0,
        clip_on=False,
    )

# Mark the points themselves on top of everything else.
for _pt in (point1_diag, point2_diag):
    ax_diagram.scatter(
        _pt[0], _pt[1], marker="o", color="k", s=50, zorder=20, clip_on=False
    )

# Dashed segment joining the two points.
ax_diagram.plot(
    [point1_diag[0], point2_diag[0]],
    [point1_diag[1], point2_diag[1]],
    "k--",
    clip_on=False,
    linewidth=1.0,
)

# Angle arcs: theta_1 at point 1, theta_2 at point 2, phi at the midpoint.
for _center, _angle_deg in (
    (point1_diag, theta1_diag_deg),
    (point2_diag, theta2_diag_deg),
    (midpoint_diag, np.rad2deg(angle_between_diag_rad)),
):
    draw_semi_circle_diag(ax_diagram, _center, _angle_deg, color="tab:blue")

# Angle labels, given as (x, y, text, horizontal alignment).
for _x, _y, _label, _ha in (
    (point1_diag[0] + 0.2, point1_diag[1] + 0.03, r"$\theta_1$", "left"),
    (point2_diag[0] + 0.1, point2_diag[1] + 0.1, r"$\theta_2$", "left"),
    (midpoint_diag[0] + 0.15, midpoint_diag[1] + 0.15, r"$\phi$", "right"),
):
    ax_diagram.text(_x, _y, _label, fontsize=14, ha=_ha, va="bottom")

ax_diagram.set_xlim(-0.25, 1.5)
ax_diagram.set_ylim(-1, 3)
ax_diagram.set_aspect("equal")
ax_diagram.axis("off")  # hide axes, spines, ticks and labels
# Rasterize artists with zorder < 1 (including the angle arcs drawn at zorder -1).
ax_diagram.set_rasterization_zorder(1)

# ==== SNR-scaling maps for Rayleigh and Love waves ====
def _plot_snr_scaling(ax, scaling_grid, cmap, title):
    """Draw one SNR-scaling map and return its pcolormesh artist.

    The map is drawn on ``ax`` over the (theta1_grid, theta2_grid) mesh, with
    the colour range fixed to [0, 1]. A colorbar labelled "Scaling Factor" and
    the panel title are added as well.
    """
    mesh = ax.pcolormesh(
        theta1_grid,
        theta2_grid,
        scaling_grid,
        cmap=cmap,
        vmin=0,
        vmax=1.0,
        shading="gouraud",
        rasterized=True,
    )
    fig.colorbar(mesh, ax=ax, pad=0.05, shrink=0.6, label="Scaling Factor")
    ax.set_title(title)
    ax.set_rasterization_zorder(0)
    return mesh


pcm_rayleigh = _plot_snr_scaling(
    axes_rayleigh, scaling_grid_rayleigh, "Reds", "Rayleigh Wave\nSNR Scaling"
)
pcm_love = _plot_snr_scaling(
    axes_love, scaling_grid_love, "Blues", "Love Wave\nSNR Scaling"
)

import os

# Shared formatting for the Rayleigh and Love panels: ticks every 90 degrees on
# [-180, 180], relative-angle axis labels, square aspect and a light grid.
# `^\circ` is used instead of a literal "°" inside $...$ because the literal
# glyph is not valid in LaTeX math mode when `text.usetex` is enabled, whereas
# `^\circ` renders under both mathtext and usetex.
scaling_angle_ticks = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
scaling_angle_tick_labels = [
    r"$-180^\circ$",
    r"$-90^\circ$",
    r"$0^\circ$",
    r"$90^\circ$",
    r"$180^\circ$",
]
for ax_main in [axes_rayleigh, axes_love]:
    ax_main.set_xticks(scaling_angle_ticks)
    ax_main.set_xticklabels(scaling_angle_tick_labels, fontsize=12)
    ax_main.set_yticks(scaling_angle_ticks)
    ax_main.set_yticklabels(scaling_angle_tick_labels, fontsize=12)
    ax_main.set_xlabel(r"$\theta_1 - \phi$", fontsize=14)
    ax_main.set_ylabel(r"$\theta_2 - \phi$", fontsize=14)
    ax_main.set_aspect("equal")
    ax_main.grid(True, linestyle="--", alpha=0.5, color="k")

# Create the output directory if needed, so the cell also runs on a fresh
# checkout instead of failing with FileNotFoundError in savefig.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/rayleigh_love_snr_scaling.png", dpi=300, bbox_inches="tight")
fig.savefig("figures/rayleigh_love_snr_scaling.pdf", dpi=300, bbox_inches="tight")

plt.show()