In [None]:
import numpy as np
import matplotlib
import matplotlib_inline.backend_inline
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.colors import LinearSegmentedColormap
from scipy.interpolate import RBFInterpolator
from matplotlib import path

matplotlib_inline.backend_inline.set_matplotlib_formats("retina")

import celeri

# Plotting functions

In [None]:
def get_mesh_indices(estimation):
    """Compute the contiguous TDE index range occupied by each mesh.

    Mesh ``i`` is assigned the half-open range ``[mesh_start_idx[i], mesh_end_idx[i])``.
    Meshes are laid out back to back in the order of ``estimation.model.meshes``.
    A vector that concatenates every mesh's TDE values can therefore be sliced as
    ``v[mesh_start_idx[i]:mesh_end_idx[i]]``.

    Args:
        estimation: Object exposing ``estimation.model.meshes``, a sequence whose
            items have an integer ``n_tde`` attribute.

    Returns:
        tuple: ``(n_meshes, mesh_start_idx, mesh_end_idx)``. The index arrays are
        integer numpy arrays of length ``n_meshes``.
    """
    meshes = estimation.model.meshes
    n_meshes = len(meshes)
    n_tde = np.array([mesh.n_tde for mesh in meshes], dtype=int)

    # A mesh ends where the running total of element counts reaches it...
    mesh_end_idx = np.cumsum(n_tde, dtype=int)
    # ...and starts where the previous mesh ended (the first mesh starts at 0).
    mesh_start_idx = mesh_end_idx - n_tde

    return n_meshes, mesh_start_idx, mesh_end_idx


def plot_segments(p, segment):
    """Draw every fault segment as a white line with a black outline.

    All black (outer) strokes are drawn first and all white (inner) strokes
    second. This keeps every white core on top of every outline, including
    where segments meet.

    Args:
        p: Plotting-options object; reads ``segment_line_width_outer`` and
            ``segment_line_width_inner``.
        segment (pd.DataFrame): Fault segments with ``lon1``, ``lat1``,
            ``lon2`` and ``lat2`` columns.
    """
    stroke_passes = (
        ("-k", "segment_line_width_outer"),
        ("-w", "segment_line_width_inner"),
    )
    for line_style, width_attr in stroke_passes:
        for row in range(len(segment)):
            plt.plot(
                [segment.lon1[row], segment.lon2[row]],
                [segment.lat1[row], segment.lat2[row]],
                line_style,
                linewidth=getattr(p, width_attr),
            )


def plot_common_elements(p):
    """Apply the map framing shared by the map-view figures to the current axes.

    Sets the longitude/latitude limits and ticks, an equal aspect ratio,
    axis labels, and the tick-label font size.

    Args:
        p: Plotting-options object; reads ``lon_range``, ``lat_range``,
            ``lon_ticks``, ``lat_ticks`` and ``fontsize``.
    """
    lon_limits = [p.lon_range[0], p.lon_range[1]]
    lat_limits = [p.lat_range[0], p.lat_range[1]]

    plt.xlim(lon_limits)
    plt.ylim(lat_limits)
    plt.xticks(p.lon_ticks)
    plt.yticks(p.lat_ticks)
    plt.gca().set_aspect("equal", adjustable="box")

    for set_label, label_text in (
        (plt.xlabel, "longitude (degrees)"),
        (plt.ylabel, "latitude (degrees)"),
    ):
        set_label(label_text, fontsize=p.fontsize)
    plt.tick_params(labelsize=p.fontsize)


def plot_vel_arrows_elements(p, station, east_velocity, north_velocity, arrow_scale):
    """Draw speed-coloured velocity arrows and a reference-arrow key, then show the figure.

    Arrow colour is ``p.arrow_colormap`` applied to the speed, normalised between
    ``p.arrow_magnitude_min`` and ``p.arrow_magnitude_max``. A filled background
    rectangle and a ``quiverkey`` reference arrow are then placed in data
    coordinates using the ``p.key_*`` settings. The function ends with
    ``plt.show()``, so all other drawing for the figure must happen before it is called.

    Args:
        p: Plotting-options object (arrow, key and font settings).
        station: Object with ``lon`` and ``lat`` (one entry per station).
        east_velocity: East velocity component per station.
        north_velocity: North velocity component per station.
        arrow_scale (float): Multiplier applied to ``p.arrow_scale_default`` to form
            the quiver ``scale``; larger values give shorter arrows.
    """
    speed = np.sqrt(east_velocity**2.0 + north_velocity**2.0)
    # Colour limits come from p rather than from this data set
    speed_norm = Normalize(vmin=p.arrow_magnitude_min, vmax=p.arrow_magnitude_max)
    arrow_colors = p.arrow_colormap(speed_norm(speed))

    arrows = plt.quiver(
        station.lon,
        station.lat,
        east_velocity,
        north_velocity,
        scale=p.arrow_scale_default * arrow_scale,
        width=p.arrow_width,
        scale_units="inches",
        color=arrow_colors,
        linewidth=p.arrow_linewidth,
        edgecolor=p.arrow_edgecolor,
    )
    ax = plt.gca()

    # Filled box behind the key arrow
    key_box = mpatches.Rectangle(
        p.key_rectangle_anchor,
        p.key_rectangle_width,
        p.key_rectangle_height,
        fill=True,
        color=p.key_background_color,
        linewidth=p.key_linewidth,
        ec=p.key_edgecolor,
    )
    ax.add_patch(key_box)

    # Reference arrow and its label
    plt.quiverkey(
        arrows,
        p.key_arrow_lon,
        p.key_arrow_lat,
        p.key_arrow_magnitude,
        p.key_arrow_text,
        coordinates="data",
        color=p.key_arrow_color,
        fontproperties={"size": p.fontsize},
    )

    ax.set_aspect("equal")
    plt.show()


def plot_mesh(meshes, fill_value):
    """Draw a mesh with faces coloured by ``fill_value`` and a black perimeter outline.

    Adds a colorbar for the face colours and autoscales the current axes to the mesh.

    Args:
        meshes: A single mesh object (despite the plural name). It provides
            ``points`` (vertex coordinates; columns 0 and 1 are used as x and y),
            ``verts`` (per-element vertex indices, presumably triangles) and
            ``ordered_edge_nodes`` (column 0 gives the perimeter node order).
        fill_value: One value per element, mapped through the ``"rainbow"`` colormap.
    """
    xs = meshes.points[:, 0]
    ys = meshes.points[:, 1]
    element_nodes = np.asarray(meshes.verts)

    axes_handle = plt.gca()
    element_polygons = np.c_[xs, ys][element_nodes]
    faces = matplotlib.collections.PolyCollection(
        element_polygons, edgecolor="none", cmap="rainbow"
    )
    faces.set_array(fill_value)
    axes_handle.add_collection(faces)
    axes_handle.autoscale()
    plt.colorbar(faces, fraction=0.046, pad=0.04)

    # Close the perimeter loop by repeating its first node at the end
    perimeter_nodes = meshes.ordered_edge_nodes[:, 0]
    closed_loop = np.append(perimeter_nodes, perimeter_nodes[0])
    plt.plot(xs[closed_loop], ys[closed_loop], color="black", linewidth=1)


def smooth_irregular_data(x_coords, y_coords, values, length_scale):
    """Gaussian-kernel smoothing of values sampled at scattered 2-D points.

    Each output value is the Gaussian-weighted mean (standard deviation
    ``length_scale``) of all input values within ``3 * length_scale`` of that
    point. More distant points are ignored for speed.

    Args:
        x_coords: 1-D array-like of x coordinates.
        y_coords: 1-D array-like of y coordinates (same length as ``x_coords``).
        values: 1-D array-like of values at those points.
        length_scale (float): Gaussian standard deviation, in coordinate units.

    Returns:
        np.ndarray: Smoothed values with the same length as ``values``. The
        result is floating point even for integer input.

    Raises:
        ValueError: If ``length_scale`` is not positive.
    """
    # scipy.spatial is not imported at the top of the notebook, so import it here.
    from scipy.spatial import cKDTree

    if length_scale <= 0:
        raise ValueError("length_scale must be positive")

    points = np.column_stack(
        (np.asarray(x_coords, dtype=float), np.asarray(y_coords, dtype=float))
    )
    values = np.asarray(values)
    # Floating-point output so weighted means of integer data are not truncated
    smoothed_values = np.empty(len(points), dtype=np.result_type(values.dtype, float))
    if len(points) == 0:
        return smoothed_values

    tree = cKDTree(points)
    # Neighbour lists for all points in one query. For finite coordinates each
    # list contains the point itself, so the weight sum below is >= 1.
    neighbor_lists = tree.query_ball_point(points, 3 * length_scale)

    for i, (point, indices) in enumerate(zip(points, neighbor_lists)):
        distances = np.linalg.norm(points[indices] - point, axis=1)
        weights = np.exp(-(distances**2) / (2 * length_scale**2))
        smoothed_values[i] = np.sum(weights * values[indices]) / np.sum(weights)

    return smoothed_values


def inpolygon(xq, yq, xv, yv):
    """Test which query points lie inside a polygon (MATLAB ``inpolygon``-style).

    Args:
        xq: numpy array of query x coordinates (any shape).
        yq: numpy array of query y coordinates (same shape as ``xq``).
        xv: numpy array of polygon vertex x coordinates, in order.
        yv: numpy array of polygon vertex y coordinates, in order.

    Returns:
        np.ndarray: Boolean array with the shape of ``xq``. An entry is ``True``
        where ``matplotlib.path.Path.contains_points`` reports the point inside.
    """
    query_xy = np.column_stack((xq.reshape(-1), yq.reshape(-1)))
    polygon = path.Path(np.column_stack((xv.reshape(-1), yv.reshape(-1))))
    return polygon.contains_points(query_xy).reshape(xq.shape)


def rbf_interpolate(mesh_idx, fill_value, estimation, n_grid_x, n_grid_y):
    """Interpolate per-element mesh values onto a regular grid with a cubic RBF.

    The grid spans the bounding box of the mesh perimeter. Values are fitted at
    the element centroids and the interpolant is evaluated at every grid node,
    so nodes outside the mesh are extrapolated.

    Args:
        mesh_idx (int): Index into ``estimation.model.meshes``.
        fill_value: One value per mesh element (centroid).
        estimation: Object exposing ``model.meshes[mesh_idx]`` with
            ``x_perimeter``, ``y_perimeter`` and ``centroids`` (columns 0/1 = x/y).
        n_grid_x (int): Number of grid nodes along x.
        n_grid_y (int): Number of grid nodes along y.

    Returns:
        tuple: ``(xgrid, ygrid)``.

        - ``xgrid`` has shape ``(2, n_grid_y, n_grid_x)`` and holds the x and y
          node coordinates (``np.meshgrid`` default "xy" layout).
        - ``ygrid`` has shape ``(n_grid_y, n_grid_x)`` and holds the
          interpolated values.
    """
    mesh = estimation.model.meshes[mesh_idx]

    # Regular grid over the perimeter's bounding box; each matrix is (n_grid_y, n_grid_x)
    x_vec = np.linspace(mesh.x_perimeter.min(), mesh.x_perimeter.max(), n_grid_x)
    y_vec = np.linspace(mesh.y_perimeter.min(), mesh.y_perimeter.max(), n_grid_y)
    x_mat, y_mat = np.meshgrid(x_vec, y_vec)

    # Observation locations (element centroids) and values
    xobs = np.column_stack((mesh.centroids[:, 0], mesh.centroids[:, 1]))
    yobs = fill_value

    # Package the grid as (n_points, 2) evaluation coordinates for RBFInterpolator
    xgrid = np.stack((x_mat, y_mat))
    xflat = xgrid.reshape(2, -1).T
    # Per the SciPy docs, for the cubic kernel epsilon acts like a rescaling of smoothing
    interpolator = RBFInterpolator(xobs, yobs, kernel="cubic", smoothing=0.01, epsilon=1.5)
    yflat = interpolator(xflat)

    # Restore meshgrid's (row = y, column = x) layout so values line up with xgrid
    ygrid = yflat.reshape(x_mat.shape)
    return xgrid, ygrid


def interpolate_to_mesh_masked_grid(mesh_idx, fill_value, estimation, n_grid_x, n_grid_y):
    """Interpolate mesh values onto a regular grid and blank nodes outside the mesh.

    Calls ``rbf_interpolate`` (which also extrapolates beyond the mesh). It then
    sets every grid node outside the mesh perimeter polygon to ``np.nan``, so
    contour plots only cover the mesh.

    Args:
        mesh_idx (int): Index into ``estimation.model.meshes``.
        fill_value: One value per mesh element.
        estimation: Object exposing ``model.meshes[mesh_idx]`` with
            ``x_perimeter`` and ``y_perimeter`` arrays.
        n_grid_x (int): Number of grid nodes along x.
        n_grid_y (int): Number of grid nodes along y.

    Returns:
        tuple: ``(x_grid, y_grid, fill_value_grid)``. The first two are the grid
        node coordinates; the third holds the interpolated values with NaN
        outside the mesh.
    """
    # Interpolate (and extrapolate!) onto a regular grid
    xgrid, ygrid = rbf_interpolate(mesh_idx, fill_value, estimation, n_grid_x, n_grid_y)
    mesh = estimation.model.meshes[mesh_idx]

    # Mask nodes outside the perimeter. The mask takes the interpolated grid's own
    # shape rather than an assumed (n_grid_x, n_grid_y), so it also lines up on
    # non-square grids.
    xflat = xgrid.reshape(2, -1).T
    inside = inpolygon(xflat[:, 0], xflat[:, 1], mesh.x_perimeter, mesh.y_perimeter)
    fill_value_grid = np.where(inside.reshape(ygrid.shape), ygrid, np.nan)

    # Provide sensible names
    x_grid = xgrid[0]
    y_grid = xgrid[1]

    return x_grid, y_grid, fill_value_grid

# Read the saved model estimation from disk


In [None]:
# Load a saved celeri estimation (model inputs under `.model` plus estimated results)
# from its run directory.
# NOTE(review): the relative path assumes the notebook is run from its own directory.
estimation = celeri.Estimation.from_disk("../../wna/runs/0000000040/")

# Get start and stop indices for tde slip rate estimations for each mesh

In [None]:
# Number of meshes and the start/end positions of each mesh's TDEs in the
# concatenated TDE vectors.
n_meshes, mesh_start_idx, mesh_end_idx =  get_mesh_indices(estimation)

# Load and modify plotting parameter dataclass

In [None]:
# Get a default plotting parameter dataclass, built from the model config, the
# estimation, and the station table
p = celeri.get_default_plotting_options(estimation.model.config, estimation, estimation.model.station)

# Modify plotting parameters below (every later plotting cell reads `p`)
p.fontsize = 12

# Plot basic model inputs

In [None]:
# Overview figure of the model inputs
celeri.plot_input_summary(estimation.model)

# Plot model estimation summary


In [None]:
# Overview figure of the estimation results
celeri.plot_estimation_summary(estimation)

# Observed velocities

In [None]:
# Observed velocities: the input station velocities
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
stations = estimation.model.station

plt.figure(figsize=p.figsize_vectors)
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)
plot_vel_arrows_elements(p, stations, stations.east_vel, stations.north_vel, arrow_scale=1.0)

# Model velocities

In [None]:
# Model velocities: the estimation's predicted velocity at each station
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
vel_east, vel_north = estimation.east_vel, estimation.north_vel

plt.figure(figsize=p.figsize_vectors)
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)
plot_vel_arrows_elements(p, estimation.model.station, vel_east, vel_north, arrow_scale=1.0)

# Residual velocities

In [None]:
# Residual velocities (arrows drawn slightly longer than the default scale)
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
vel_east, vel_north = estimation.east_vel_residual, estimation.north_vel_residual

plt.figure(figsize=p.figsize_vectors)
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)
plot_vel_arrows_elements(p, estimation.model.station, vel_east, vel_north, arrow_scale=0.9)

# Rotation velocities

In [None]:
# Rotation velocities: the block-rotation contribution at each station
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
vel_east, vel_north = estimation.east_vel_rotation, estimation.north_vel_rotation

plt.figure(figsize=p.figsize_vectors)
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)
plot_vel_arrows_elements(p, estimation.model.station, vel_east, vel_north, arrow_scale=1.0)

# Fully locked segment velocities

In [None]:
# Elastic segment velocities (arrows drawn 4x longer than the default scale)
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
vel_east = estimation.east_vel_elastic_segment
vel_north = estimation.north_vel_elastic_segment

plt.figure(figsize=p.figsize_vectors)
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)
plot_vel_arrows_elements(p, estimation.model.station, vel_east, vel_north, arrow_scale=0.25)

# Mesh velocities

In [None]:
# Elastic TDE (mesh) velocities
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
vel_east, vel_north = estimation.east_vel_tde, estimation.north_vel_tde

plt.figure(figsize=p.figsize_vectors)
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)
plot_vel_arrows_elements(p, estimation.model.station, vel_east, vel_north, arrow_scale=1.00)

# Total elastic velocities

In [None]:
# Total elastic velocities: TDE (mesh) plus elastic segment contributions
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
vel_east = estimation.east_vel_tde + estimation.east_vel_elastic_segment
vel_north = estimation.north_vel_tde + estimation.north_vel_elastic_segment

plt.figure(figsize=p.figsize_vectors)
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)
plot_vel_arrows_elements(p, estimation.model.station, vel_east, vel_north, arrow_scale=1.00)

# Strain velocities

In [None]:
# Block strain-rate velocities (arrows drawn 100x longer than the default scale)
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
vel_east = estimation.east_vel_block_strain_rate
vel_north = estimation.north_vel_block_strain_rate

plt.figure(figsize=p.figsize_vectors)
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)
plot_vel_arrows_elements(p, estimation.model.station, vel_east, vel_north, arrow_scale=0.01)

# Mogi velocities

In [None]:
# Mogi source velocities (arrows drawn 10x longer than the default scale)
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
vel_east, vel_north = estimation.east_vel_mogi, estimation.north_vel_mogi

plt.figure(figsize=p.figsize_vectors)
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)
plot_vel_arrows_elements(p, estimation.model.station, vel_east, vel_north, arrow_scale=0.10)

# Residual velocity histogram

In [None]:
# Residual velocities: east and north components pooled into one vector
residual_velocity_vector = np.concatenate(
    (estimation.east_vel_residual, estimation.north_vel_residual)
)
mean_absolute_error = np.mean(np.abs(residual_velocity_vector))
# MSE is the *mean* of the squared residuals, not their sum
mean_squared_error = np.mean(residual_velocity_vector**2.0)

# Histogram parameters: n_bins bins need n_bins + 1 edges
n_bins = 100
bin_edge_min = -15
bin_edge_max = 15
bins = np.linspace(bin_edge_min, bin_edge_max, n_bins + 1)

# Histogram of residual velocities. Residuals outside the bin range are not
# drawn, but they are included in MAE/MSE.
plt.figure(figsize=(8, 4))
plt.hist(
    residual_velocity_vector,
    bins,
    histtype="stepfilled",
    color="lightblue",
    edgecolor="k",
)
plt.xlim([bin_edge_min, bin_edge_max])
plt.xlabel("residual velocity (mm/yr)", fontsize=p.fontsize)
plt.ylabel("N", fontsize=p.fontsize)
plt.title(f"MAE = {mean_absolute_error:.2f} (mm/yr), MSE = {mean_squared_error:.2f} (mm$^2$/yr$^2$)", fontsize=p.fontsize)
plt.tick_params(labelsize=p.fontsize)
plt.show()

# Residual velocity scatter plot

In [None]:
# Scatter plot for velocity estimate errors:
# per-station residual magnitude |east residual| + |north residual| (mm/yr)
mae_station = np.abs(estimation.east_vel_residual) + np.abs(
    estimation.north_vel_residual
)
plt.figure(figsize=p.figsize_vectors)
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
celeri.plot_coastlines(p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])

# Stations coloured by residual magnitude; colours saturate at 10 mm/yr
residual_scatter = plt.scatter(
    estimation.model.station.lon,
    estimation.model.station.lat,
    s=25,
    edgecolors="k",
    c=mae_station,
    cmap="YlOrRd",
    linewidths=0.1,
)
plt.clim(0, 10)
# Without a colorbar the marker colours cannot be read quantitatively
plt.colorbar(
    residual_scatter,
    fraction=0.046,
    pad=0.04,
    label="|east residual| + |north residual| (mm/yr)",
)

# Draw the key background rectangle (same box as in the velocity maps)
rect = mpatches.Rectangle(
    p.key_rectangle_anchor,
    p.key_rectangle_width,
    p.key_rectangle_height,
    fill=True,
    color=p.key_background_color,
    linewidth=p.key_linewidth,
    ec=p.key_edgecolor,
)
plt.gca().add_patch(rect)

# Zoom on a sub-region; this overrides the p.lon_range / p.lat_range limits set above
plt.ylim(31, 40)
plt.xlim(237, 247)
plt.show()

# Plot posterior draws of coupling on the first mesh (mesh_idx = 0)

In [None]:
# Posterior draws of coupling on one mesh, one panel per draw
vmin = 0
vmax = 1
center = 0.5
cmap = "seismic"
kind = "dip_slip"
mesh_idx = 0

mesh = estimation.model.meshes[mesh_idx]
posterior_coupling = estimation.mcmc_trace.posterior[f"coupling_{mesh_idx}_{kind}"]
# Every 100th draw up to 700, skipping indices beyond the chain length
draw_indices = [d for d in range(0, 800, 100) if d < posterior_coupling.sizes["draw"]]

# Loop-invariant geometry: element polygons and the closed perimeter outline
x_coords = mesh.points[:, 0]
y_coords = mesh.points[:, 1]
verts = np.c_[x_coords, y_coords][np.asarray(mesh.verts)]
edge_nodes = mesh.ordered_edge_nodes[:, 0]
x_edge = np.append(x_coords[edge_nodes], x_coords[edge_nodes[0]])
y_edge = np.append(y_coords[edge_nodes], y_coords[edge_nodes[0]])
norm = matplotlib.colors.TwoSlopeNorm(vmin=vmin, vcenter=center, vmax=vmax)

fig, axes = plt.subplots(2, 4, figsize=(12, 8))
for ax, draw in zip(axes.flat, draw_indices):
    fill_value = np.asarray(posterior_coupling.isel(chain=0, draw=draw))
    pc = matplotlib.collections.PolyCollection(
        verts,
        edgecolor="none",
        cmap=cmap,
        norm=norm,
    )
    pc.set_array(fill_value)
    ax.add_collection(pc)
    ax.plot(x_edge, y_edge, color="black", linewidth=1)
    ax.set_aspect("equal")
    ax.set_title(f"draw {draw}")

# One shared colorbar; built from the norm so it exists even if no draw was plotted
fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=axes, label=f"{kind} coupling")

In [None]:
# Posterior draw, or `None` for the posterior mean
import copy

draw = 100

if draw is not None:
    est = estimation.mcmc_draw(draw=draw, chain=0)
else:
    est = estimation

# Zoomed map of strike-slip rates and residual velocities for one posterior draw.
# Work on a copy of the plotting options: the zoom then cannot leak into later
# cells, and land/coastlines are drawn for the same extent that is displayed.
p_zoom = copy.copy(p)
p_zoom.lon_range = (237, 247)
p_zoom.lat_range = (31, 40)
p_zoom.figsize_vectors = (24, 20)
zoom_bounds = (p_zoom.lon_range[0], p_zoom.lat_range[0], p_zoom.lon_range[1], p_zoom.lat_range[1])

plt.figure(figsize=p_zoom.figsize_vectors)
plot_segments(p_zoom, est.model.segment)
celeri.plot_land(*zoom_bounds)
celeri.plot_coastlines(*zoom_bounds)

# Segment lines: width proportional to |strike-slip rate|, colour by sign.
# Negative line widths are not meaningful and may render differently across backends.
slip_rate_width_scale = 0.25
for i in range(len(est.model.segment)):
    slip = est.strike_slip_rates[i]
    plt.plot(
        [est.model.segment.lon1[i], est.model.segment.lon2[i]],
        [est.model.segment.lat1[i], est.model.segment.lat2[i]],
        "-",
        color="tab:orange" if slip < 0 else "tab:blue",
        linewidth=slip_rate_width_scale * abs(slip),
    )

# Legend entries are lines drawn at the width of a 10 mm/yr rate
right_lateral_handle = mlines.Line2D(
    [],
    [],
    color="tab:orange",
    linewidth=slip_rate_width_scale * 10.0,
    label="right-lateral (10 mm/yr)",
)
left_lateral_handle = mlines.Line2D(
    [],
    [],
    color="tab:blue",
    linewidth=slip_rate_width_scale * 10.0,
    label="left-lateral (10 mm/yr)",
)
plt.legend(
    handles=[right_lateral_handle, left_lateral_handle],
    loc="lower left",
    fontsize=p_zoom.fontsize,
    framealpha=1.0,
    edgecolor="k",
).get_frame().set_boxstyle("Square")

plt.xlim(p_zoom.lon_range)
plt.ylim(p_zoom.lat_range)

# Residual velocity arrows for this draw (also shows the figure)
plot_vel_arrows_elements(
    p_zoom,
    est.model.station,
    est.station.model_east_vel_residual,
    est.station.model_north_vel_residual,
    arrow_scale=1,
)

In [None]:
# Estimated slip rates (C1 line) against their bounds (shaded band), bounded segments only
fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(20, 8))
bound_panels = (
    (ax1, "ss_rate_bound_flag", "ss_rate_bound_min", "ss_rate_bound_max", "model_strike_slip_rate", "strike slip"),
    (ax2, "ds_rate_bound_flag", "ds_rate_bound_min", "ds_rate_bound_max", "model_dip_slip_rate", "dip slip"),
)
for panel_ax, flag_col, min_col, max_col, rate_col, panel_label in bound_panels:
    bounded_segments = estimation.segment[estimation.segment[flag_col] != 0]
    minval, maxval = bounded_segments[[min_col, max_col]].values.T
    positions = np.arange(len(minval))
    panel_ax.fill_between(positions, minval, maxval)
    panel_ax.plot(positions, bounded_segments[rate_col], color="C1")
    panel_ax.axhline(0, color="black", zorder=-100)
    panel_ax.set_ylabel(panel_label)
    panel_ax.set_xlabel("segment")

fig.tight_layout();

# Fault slip rates (strike-slip)

In [None]:
# Plot estimated strike-slip rates: line width proportional to |rate|, colour by sign
slip_rate_width_scale = 0.25

plt.figure(figsize=p.figsize_vectors)
plot_common_elements(p)
celeri.plot_land(p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
celeri.plot_coastlines(p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])

# Plot fault slip rates. Widths use |rate| because negative line widths are not
# meaningful and may render differently across matplotlib backends.
for i in range(len(estimation.model.segment)):
    rate = estimation.strike_slip_rates[i]
    plt.plot(
        [estimation.model.segment.lon1[i], estimation.model.segment.lon2[i]],
        [estimation.model.segment.lat1[i], estimation.model.segment.lat2[i]],
        "-",
        color="tab:orange" if rate < 0 else "tab:blue",
        linewidth=slip_rate_width_scale * abs(rate),
    )

# Legend entries are lines drawn at the width of a 10 mm/yr rate
right_lateral_handle = mlines.Line2D(
    [],
    [],
    color="tab:orange",
    linewidth=slip_rate_width_scale * 10.0,
    label="right-lateral (10 mm/yr)",
)
left_lateral_handle = mlines.Line2D(
    [],
    [],
    color="tab:blue",
    linewidth=slip_rate_width_scale * 10.0,
    label="left-lateral (10 mm/yr)",
)
plt.legend(
    handles=[right_lateral_handle, left_lateral_handle],
    loc="lower right",
    fontsize=p.fontsize,
    framealpha=1.0,
    edgecolor="k",
).get_frame().set_boxstyle("Square")

In [None]:
# Scaling factor for slip rate widths (points of line width per mm/yr)
slip_rate_width_scale = 0.25

# Plot estimated dip/tensile-slip rates
plt.figure(figsize=p.figsize_vectors)
plot_common_elements(p)
celeri.plot_land(p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
celeri.plot_coastlines(p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])

# Widths use |rate| because negative line widths are not meaningful and may
# render differently across matplotlib backends.
# Dip slip: positive -> convergence (orange), otherwise extension (blue)
for i in range(len(estimation.model.segment)):
    rate = estimation.dip_slip_rates[i]
    plt.plot(
        [estimation.model.segment.lon1[i], estimation.model.segment.lon2[i]],
        [estimation.model.segment.lat1[i], estimation.model.segment.lat2[i]],
        "-",
        color="tab:orange" if rate > 0 else "tab:blue",
        linewidth=slip_rate_width_scale * abs(rate),
    )

# Tensile slip: negative -> convergence (orange), otherwise extension (blue).
# These lines are drawn on top of the dip-slip lines for the same segments.
for i in range(len(estimation.model.segment)):
    rate = estimation.tensile_slip_rates[i]
    plt.plot(
        [estimation.model.segment.lon1[i], estimation.model.segment.lon2[i]],
        [estimation.model.segment.lat1[i], estimation.model.segment.lat2[i]],
        "-",
        color="tab:orange" if rate < 0 else "tab:blue",
        linewidth=slip_rate_width_scale * abs(rate),
    )

# Legend entries are lines drawn at the width of a 10 mm/yr rate
convergence_handle = mlines.Line2D(
    [],
    [],
    color="tab:orange",
    linewidth=slip_rate_width_scale * 10.0,
    label="convergence (10 mm/yr)",
)
extension_handle = mlines.Line2D(
    [],
    [],
    color="tab:blue",
    linewidth=slip_rate_width_scale * 10.0,
    label="extension (10 mm/yr)",
)
plt.legend(
    handles=[convergence_handle, extension_handle],
    loc="lower right",
    fontsize=p.fontsize,
    framealpha=1.0,
    edgecolor="k",
).get_frame().set_boxstyle("Square")

# Quick mesh plots

In [None]:
# TODO: Modify with axes flags

# Quick look at the TDE strike-slip ("ss") and dip-slip ("ds") rates on one mesh
mesh_idx = 0
for rate_attr, panel_title in (
    ("tde_strike_slip_rates", "ss elastic"),
    ("tde_dip_slip_rates", "ds elastic"),
):
    plot_mesh(estimation.model.meshes[mesh_idx], getattr(estimation, rate_attr)[mesh_idx])
    plot_segments(p, estimation.model.segment)
    plot_common_elements(p)
    plt.gca().set_aspect("equal")
    plt.title(panel_title)
    plt.show()


# Elastically contributing dip-slip rates (filled contour)

In [None]:
# Estimated (elastic) TDE dip-slip rates, contoured on a grid masked to the mesh
mesh_idx = 0
max_vel = 60
levels = np.linspace(-max_vel, max_vel, 11)
n_grid_x = 500
n_grid_y = 500
mesh = estimation.model.meshes[mesh_idx]
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])

# Interpolate onto a regular grid (NaN outside the mesh perimeter)
fill_value = estimation.tde_dip_slip_rates[mesh_idx]
x_grid, y_grid, fill_value_grid = interpolate_to_mesh_masked_grid(mesh_idx, fill_value, estimation, n_grid_x, n_grid_y)

# Map background
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)

# Filled contours plus thin black contour lines at the same levels
ch = plt.contourf(x_grid, y_grid, fill_value_grid, cmap="Spectral_r", levels=levels, extend="both")
plt.contour(x_grid, y_grid, fill_value_grid, colors="k", linestyles="solid", linewidths=0.25, levels=levels)

# linewidth=0.0 makes this perimeter line invisible. NOTE(review): confirm whether an outline was intended.
plt.plot(mesh.x_perimeter, mesh.y_perimeter, "-k", linewidth=0.0)
ax = plt.gca()
ax.set_aspect("equal", adjustable="box")

# Colorbar in a small inset near the lower right of the map
cax = inset_axes(
    ax,
    width="20%",
    height="30%",
    loc="lower left",
    bbox_to_anchor=(0.70, 0.05, 0.10, 0.95),
    bbox_transform=ax.transAxes,
    borderpad=0,
)
cbar = plt.colorbar(ch, cax=cax, ticks=[-max_vel, 0, max_vel], label="v (mm/yr)")

# Kinematic rates (filled contour)

In [None]:
# Estimated kinematic TDE dip-slip rates, contoured on a grid masked to the mesh
mesh_idx = 0
max_vel = 60
levels = np.linspace(-max_vel, max_vel, 11)
n_grid_x = 500
n_grid_y = 500
mesh = estimation.model.meshes[mesh_idx]
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])

# Interpolate onto a regular grid (NaN outside the mesh perimeter)
fill_value = estimation.tde_dip_slip_rates_kinematic[mesh_idx]
x_grid, y_grid, fill_value_grid = interpolate_to_mesh_masked_grid(mesh_idx, fill_value, estimation, n_grid_x, n_grid_y)

# Map background
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)

# Filled contours plus thin black contour lines at the same levels
ch = plt.contourf(x_grid, y_grid, fill_value_grid, cmap="Spectral_r", levels=levels, extend="both")
plt.contour(x_grid, y_grid, fill_value_grid, colors="k", linestyles="solid", linewidths=0.25, levels=levels)

# linewidth=0.0 makes this perimeter line invisible. NOTE(review): confirm whether an outline was intended.
plt.plot(mesh.x_perimeter, mesh.y_perimeter, "-k", linewidth=0.0)
ax = plt.gca()
ax.set_aspect("equal", adjustable="box")

# Colorbar in a small inset near the lower right of the map
cax = inset_axes(
    ax,
    width="20%",
    height="30%",
    loc="lower left",
    bbox_to_anchor=(0.70, 0.05, 0.10, 0.95),
    bbox_transform=ax.transAxes,
    borderpad=0,
)
cbar = plt.colorbar(ch, cax=cax, ticks=[-max_vel, 0, max_vel], label="v (mm/yr)")

# Kinematic rates (smoothed, filled contour)

In [None]:
# Estimated smoothed kinematic TDE dip-slip rates, contoured on a grid masked to the mesh
mesh_idx = 0
max_vel = 60
levels = np.linspace(-max_vel, max_vel, 11)
n_grid_x = 500
n_grid_y = 500
mesh = estimation.model.meshes[mesh_idx]
map_bounds = (p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])

# Interpolate onto a regular grid (NaN outside the mesh perimeter).
# TODO(review): the author was not sure this is the right field to plot here; confirm.
fill_value = estimation.tde_dip_slip_rates_kinematic_smooth[mesh_idx]
x_grid, y_grid, fill_value_grid = interpolate_to_mesh_masked_grid(mesh_idx, fill_value, estimation, n_grid_x, n_grid_y)

# Map background
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(*map_bounds)
celeri.plot_coastlines(*map_bounds)

# Filled contours plus thin black contour lines at the same levels
ch = plt.contourf(x_grid, y_grid, fill_value_grid, cmap="Spectral_r", levels=levels, extend="both")
plt.contour(x_grid, y_grid, fill_value_grid, colors="k", linestyles="solid", linewidths=0.25, levels=levels)

# linewidth=0.0 makes this perimeter line invisible. NOTE(review): confirm whether an outline was intended.
plt.plot(mesh.x_perimeter, mesh.y_perimeter, "-k", linewidth=0.0)
ax = plt.gca()
ax.set_aspect("equal", adjustable="box")

# Colorbar in a small inset near the lower right of the map
cax = inset_axes(
    ax,
    width="20%",
    height="30%",
    loc="lower left",
    bbox_to_anchor=(0.70, 0.05, 0.10, 0.95),
    bbox_transform=ax.transAxes,
    borderpad=0,
)
cbar = plt.colorbar(ch, cax=cax, ticks=[-max_vel, 0, max_vel], label="v (mm/yr)")

# Coupling (filled contour)

In [None]:
# Estimated dip-slip coupling, contoured on a grid masked to the mesh.
# mesh_idx is set here so this cell does not depend on a value left by an earlier cell.
mesh_idx = 0
levels = np.linspace(0, 1, 11)
n_grid_x = 500
n_grid_y = 500

# Interpolate onto a regular grid (NaN outside the mesh perimeter)
fill_value = estimation.tde_dip_slip_rates_coupling[mesh_idx]
x_grid, y_grid, fill_value_grid = interpolate_to_mesh_masked_grid(mesh_idx, fill_value, estimation, n_grid_x, n_grid_y)

# Plot geometry with contoured mesh values
plot_segments(p, estimation.model.segment)
plot_common_elements(p)
celeri.plot_land(p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])
celeri.plot_coastlines(p.lon_range[0], p.lat_range[0], p.lon_range[1], p.lat_range[1])

# Plot contours; extend="both" gives values outside [0, 1] the end colours
ch = plt.contourf(x_grid, y_grid, fill_value_grid, cmap="plasma_r", levels=levels, extend="both")
plt.contour(
    x_grid, y_grid, fill_value_grid, colors="k", linestyles="solid", linewidths=0.25, levels=levels
)

# linewidth=0.0 makes this perimeter line invisible. NOTE(review): confirm whether an outline was intended.
plt.plot(estimation.model.meshes[mesh_idx].x_perimeter, estimation.model.meshes[mesh_idx].y_perimeter, "-k", linewidth=0.0)
plt.gca().set_aspect("equal", adjustable="box")
cax = inset_axes(
    plt.gca(),
    width="20%",
    height="30%",
    loc="lower left",
    bbox_to_anchor=(0.70, 0.05, 0.10, 0.95),  # Position in axes fraction
    bbox_transform=plt.gca().transAxes,
    borderpad=0,
)
cbar = plt.colorbar(ch, cax=cax, ticks=[0, 1], label="coupling")

In [None]:
import arviz

In [None]:
# Effective sample size (ESS) of each variable in the MCMC posterior
ess = arviz.ess(estimation.mcmc_trace.posterior)

In [None]:
# Smallest ESS of each remaining variable, excluding the Mogi and block-strain-rate parameters
excluded_vars = [
    "mogi",
    "mogi_param",
    "mogi_raw",
    "block_strain_rate",
    "block_strain_rate_raw",
    "block_strain_rate_param",
]
ess.drop_vars(excluded_vars).min().to_pandas()

# SQP convergence history

In [None]:
# Convergence history of the iterative (SQP) solve
celeri.plot_iterative_convergence(estimation)
