In [None]:
import pandas as pd
import os
import json
import zipfile
from io import StringIO, BytesIO
import numpy as np
from scipy.interpolate import griddata, NearestNDInterpolator
from sklearn.neighbors import KDTree

In [None]:
def interpolate_data(
    df: pd.DataFrame,
    grid_params: dict,
    k: int = 3,
    power: float = 1.0,
    average_duplicates: bool = True,
) -> pd.DataFrame:
    """
    Regrid all numeric variables in df onto the 2D grid defined by `grid_params`
    using k-nearest-neighbour inverse-distance weighting (IDW).

    Parameters
    ----------
    df : pd.DataFrame
        Scattered samples. Must contain both axis columns; every other numeric
        column is interpolated.
    grid_params : dict
        Mapping ``axis name -> {"min": ..., "max": ..., "n": ...}`` with exactly
        two entries. Axis names (and their order) are taken from the keys.
    k : int
        Number of nearest input points blended per grid node (>= 1). It is
        capped at the number of available input points.
    power : float
        IDW exponent: weights are proportional to ``1 / (d**power + 1e-12)``.
        ``power == 0`` gives the plain mean of the k neighbours.
    average_duplicates : bool
        If True, rows sharing the same axis coordinates are averaged first.

    Returns
    -------
    pd.DataFrame
        One row per grid node with the two axis columns followed by the
        interpolated variables. Nodes are ordered with the first axis varying
        fastest.

    Raises
    ------
    ValueError
        If `grid_params` does not define exactly 2 axes, `k` < 1, a grid size
        is < 1, there is no numeric variable, or no usable input point remains.
    KeyError
        If an axis column is missing from `df`.

    Notes
    -----
    - Distances are Euclidean in the raw axis units; the axes are not rescaled.
    - Rows whose axis coordinates are NaN/inf are ignored.
    - A grid node that coincides with input points takes the mean of those
      points; neighbours with zero weight never contribute, so a NaN in such a
      neighbour does not leak into the result.
    """
    axes = list(grid_params.keys())
    if len(axes) != 2:
        raise ValueError(f"grid_params must have exactly 2 axes, got: {axes}")

    a0, a1 = axes[0], axes[1]  # preserve naming from JSON
    if a0 not in df.columns or a1 not in df.columns:
        raise KeyError(
            f"Axis columns {a0!r}, {a1!r} must exist in df. "
            f"df columns: {list(df.columns)}"
        )

    k = int(k)
    if k < 1:
        raise ValueError(f"k must be a positive integer, got: {k}")

    # numeric variables to interpolate = all numeric except axes
    all_numeric = df.select_dtypes(include=[np.number]).columns.tolist()
    variables = [c for c in all_numeric if c not in (a0, a1)]
    if not variables:
        raise ValueError("No numeric variables found to interpolate (besides the axes).")

    # Points without a finite position cannot be placed in the tree; drop them
    # up front so both code paths below behave the same way.
    work = df[[a0, a1] + variables]
    coords = work[[a0, a1]].to_numpy(dtype=float, na_value=np.nan)
    work = work.loc[np.isfinite(coords).all(axis=1)]

    # Optionally average duplicate axis points
    if average_duplicates:
        work = work.groupby([a0, a1], as_index=False).mean(numeric_only=True)
    else:
        work = work.copy()

    # Build query grid
    n0 = int(grid_params[a0]["n"])
    n1 = int(grid_params[a1]["n"])
    if n0 < 1 or n1 < 1:
        raise ValueError(
            f"Grid sizes must be >= 1, got {a0!r}: {n0}, {a1!r}: {n1}"
        )
    xi = np.linspace(grid_params[a0]["min"], grid_params[a0]["max"], n0)
    yi = np.linspace(grid_params[a1]["min"], grid_params[a1]["max"], n1)
    X, Y = np.meshgrid(xi, yi, indexing="xy")
    grid_points = np.column_stack([X.ravel(), Y.ravel()])

    # KDTree query
    pts = work[[a0, a1]].to_numpy(dtype=float)
    n_pts = len(pts)
    if n_pts == 0:
        raise ValueError("No input points to interpolate.")

    k_eff = min(k, n_pts)
    tree = KDTree(pts)
    dist, ind = tree.query(grid_points, k=k_eff)

    # Weights
    if power == 0:
        weights = np.full(dist.shape, 1.0 / dist.shape[1], dtype=float)
    else:
        with np.errstate(divide="ignore"):
            w = 1.0 / (np.power(dist, power) + 1e-12)

        # exact matches: give all weight to the zero-distance neighbor(s)
        exact = dist < 1e-12
        hit_rows = exact.any(axis=1)
        if hit_rows.any():
            hits = exact[hit_rows]
            w[hit_rows] = hits / hits.sum(axis=1, keepdims=True)

        row_sums = w.sum(axis=1, keepdims=True)
        row_sums[row_sums == 0] = 1.0
        weights = w / row_sums

    # Interpolate. Zero-weight neighbours are masked out explicitly because
    # 0 * NaN (or 0 * inf) is NaN and would otherwise corrupt exact matches.
    contributes = weights != 0
    out = {a0: grid_points[:, 0], a1: grid_points[:, 1]}
    for var in variables:
        neigh_vals = work[var].to_numpy()[ind]
        with np.errstate(invalid="ignore"):
            weighted = np.where(contributes, weights * neigh_vals, 0.0)
        out[var] = weighted.sum(axis=1)

    return pd.DataFrame(out)

# Test interpolation on the BP8-QD slip evolution data

In [None]:
def read_slip_long(
    path: str,
    *,
    z_name: str = "z",
    t_name: str = "t",
    slip_name: str = "slip",
    max_slip_rate_name: str = "max_slip_rate",
    keep_max_slip_rate: bool = True,
    n_header_zeros: int = 2,
    min_fields_for_axis_row: int = 10,
) -> pd.DataFrame:
    """
    Parse a SEAS-style slip evolution text file into a long (tidy) table.

    Expected layout of the file:
      - lines starting with '#' are comments and are ignored
      - schema lines (e.g. 'z', 't', 'max_slip_rate', 'slip') may appear and are ignored
      - the axis row is the first all-numeric line with more than
        `min_fields_for_axis_row` fields; its first `n_header_zeros` entries
        are placeholders and are discarded, the rest are the z positions
      - every following all-numeric line with exactly len(z) + 2 fields is a
        data row: t, max_slip_rate, slip(z1..zN); other lines are skipped

    Returns one row per (t, z) pair with columns t_name, z_name, slip_name and,
    when `keep_max_slip_rate` is True, max_slip_rate_name.

    Raises ValueError when no axis row or no matching data row is found.
    """
    with open(path, "r") as handle:
        raw_lines = handle.readlines()

    def _is_comment(text: str) -> bool:
        return text.lstrip().startswith("#")

    def _all_numeric(tokens) -> bool:
        # True when every token is accepted by float()
        for tok in tokens:
            try:
                float(tok)
            except ValueError:
                return False
        return True

    # Locate the axis row: first non-comment line that is long enough and fully numeric
    axis_pos = next(
        (
            pos
            for pos, text in enumerate(raw_lines)
            if not _is_comment(text)
            and len(text.split()) > min_fields_for_axis_row
            and _all_numeric(text.split())
        ),
        None,
    )
    if axis_pos is None:
        raise ValueError("Could not find axis row (many numeric fields).")

    # z positions, minus the leading placeholder entries
    depths = np.array(raw_lines[axis_pos].split(), dtype=float)
    if n_header_zeros:
        depths = depths[n_header_zeros:]

    # A data row carries t, max_slip_rate and one slip value per z position
    n_cols = len(depths) + 2
    candidates = (
        text.split() for text in raw_lines[axis_pos + 1:] if not _is_comment(text)
    )
    records = [
        tokens
        for tokens in candidates
        if len(tokens) == n_cols and _all_numeric(tokens)
    ]
    if not records:
        raise ValueError("No data rows found matching expected length.")

    table = np.array(records, dtype=float)
    times, peak_rates, slip_grid = table[:, 0], table[:, 1], table[:, 2:]  # slip_grid: (nt, nz)

    # Flatten row-major: t changes slowest, z fastest
    nz = len(depths)
    columns = {
        t_name: np.repeat(times, nz),
        z_name: np.tile(depths, len(times)),
        slip_name: slip_grid.reshape(-1),
    }
    if keep_max_slip_rate:
        columns[max_slip_rate_name] = np.repeat(peak_rates, nz)

    return pd.DataFrame(columns)

# Usage: parse the slip evolution file into a long-form table
# (one row per (t, z) sample) and preview the first rows.
tidy = read_slip_long("../resources/bp8-qd/slip2.dat", keep_max_slip_rate=True)
tidy.head()

In [None]:
# Bind the parsed long-form table to `df` and display it.
df = tidy
df

In [None]:
# Load the benchmark template that describes the expected output files.
template_path = "../resources/benchmark_templates/bp8-qd-slip2Evolution.json"
with open(template_path, "r") as f:
    template = json.load(f)

# Regrid `df` onto the first template entry that declares a "grid".
# Only one entry is processed on purpose: this cell is a smoke test.
for file_info in template["files"]:
    expected_structure = file_info
    if "grid" in expected_structure:
        interpolated_df = interpolate_data(df, expected_structure["grid"])
        break  # Just run one for the testing
else:
    # Fail loudly: otherwise the last line raises NameError on a fresh kernel,
    # or silently shows a stale `interpolated_df` left over from an earlier run.
    raise ValueError(f"No entry with a 'grid' definition found in {template_path!r}")
interpolated_df

# Plots of the interpolated slip field

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
# === Inputs: gridded table, plotted variable and its label metadata ===
df = interpolated_df
var = "slip"
variable_dict = dict(name=var, unit="m")

# --- Reshape the long table into a 2-D field: rows follow z, columns follow t ---
grid = df.pivot(index="z", columns="t", values=var)
grid = grid.sort_index(axis=0).sort_index(axis=1)

Z = grid.to_numpy()  # shape (nz, nt)
z_unique = grid.index.to_numpy()
t_unique = grid.columns.to_numpy()

# Colour limits span the non-NaN values of the field
zmin, zmax = float(np.nanmin(Z)), float(np.nanmax(Z))

# --- Select the grid row closest to z = 0 for the cross-section ---
z0_target = 0.0
iz0 = int(np.abs(z_unique - z0_target).argmin())
z0 = z_unique[iz0]

profile_vs_t = Z[iz0, :]  # slip(t) at fixed z

# Wide heatmap panel on the left, narrow line plot on the right
fig = make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "heatmap"}, {"type": "xy"}]],
    column_widths=[0.78, 0.22],
    horizontal_spacing=0.06,
    subplot_titles=(
        f"{variable_dict['name']} [{variable_dict['unit']}] on (t, z)",
        f"Cross-section at z={z0:.2f} m",
    ),
)

# Left panel: field as a heatmap with t on x and z on y
fig.add_trace(
    go.Heatmap(
        z=Z,
        x=t_unique,
        y=z_unique,
        colorscale="RdBu_r",
        zmin=zmin,
        zmax=zmax,
        colorbar={"title": f"{variable_dict['name']} ({variable_dict['unit']})"},
    ),
    row=1,
    col=1,
)

# Dashed horizontal marker showing where the cross-section is taken
fig.add_shape(
    type="line",
    y0=z0,
    y1=z0,
    x0=t_unique.min(),
    x1=t_unique.max(),
    line={"width": 2, "dash": "dash"},
    row=1,
    col=1,
)

# Right panel: slip as a function of t along the selected row
fig.add_trace(
    go.Scatter(
        x=t_unique,
        y=profile_vs_t,
        name=f"{variable_dict['name']} @ z={z0:.2f} m",
        mode="lines",
    ),
    row=1,
    col=2,
)

# Axis titles, panel by panel
fig.update_xaxes(row=1, col=1, title_text="t (s)")
fig.update_yaxes(row=1, col=1, title_text="z (m)")

fig.update_xaxes(row=1, col=2, title_text="t (s)")
fig.update_yaxes(row=1, col=2, title_text=f"{variable_dict['name']} ({variable_dict['unit']})")

# Global figure styling
fig.update_layout(
    title="BP8-QD Slip Evolution",
    template="plotly_white",
    height=800,
    showlegend=False,
    margin={"l": 60, "r": 20, "t": 60, "b": 50},
)

fig.show()

In [None]:
# === Inputs: gridded table, plotted variable and its label metadata ===
df = interpolated_df
var = "slip"
variable_dict = dict(name=var, unit="m")

# --- Reshape the long table into a 2-D field: rows follow z, columns follow t ---
grid = df.pivot(index="z", columns="t", values=var)
grid = grid.sort_index(axis=0).sort_index(axis=1)

Z = grid.to_numpy()  # shape (nz, nt)
z_unique = grid.index.to_numpy()
t_unique = grid.columns.to_numpy()

# Colour limits span the non-NaN values of the field
zmin, zmax = float(np.nanmin(Z)), float(np.nanmax(Z))

# --- Select the grid column closest to t = 1,000,000 s for the cross-section ---
t0_target = 1_000_000.0
it0 = int(np.abs(t_unique - t0_target).argmin())
t0 = t_unique[it0]

profile_vs_z = Z[:, it0]  # slip(z) at fixed t

# Wide heatmap panel on the left, narrow line plot on the right
fig = make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "heatmap"}, {"type": "xy"}]],
    column_widths=[0.78, 0.22],
    horizontal_spacing=0.06,
    subplot_titles=(
        f"{variable_dict['name']} [{variable_dict['unit']}] on (t, z)",
        f"Cross-section at t={t0:.0f} s",
    ),
)

# Left panel: field as a heatmap with t on x and z on y
fig.add_trace(
    go.Heatmap(
        z=Z,
        x=t_unique,
        y=z_unique,
        colorscale="RdBu_r",
        zmin=zmin,
        zmax=zmax,
        colorbar={"title": f"{variable_dict['name']} ({variable_dict['unit']})"},
    ),
    row=1,
    col=1,
)

# Dashed vertical marker showing where the cross-section is taken
fig.add_shape(
    type="line",
    y0=z_unique.min(),
    y1=z_unique.max(),
    x0=t0,
    x1=t0,
    line={"width": 2, "dash": "dash"},
    row=1,
    col=1,
)

# Right panel: slip on x against z on y along the selected column
fig.add_trace(
    go.Scatter(
        x=profile_vs_z,
        y=z_unique,
        name=f"{variable_dict['name']} @ t={t0:.0f} s",
        mode="lines",
    ),
    row=1,
    col=2,
)

# Axis titles, panel by panel
fig.update_xaxes(row=1, col=1, title_text="t (s)")
fig.update_yaxes(row=1, col=1, title_text="z (m)")

fig.update_xaxes(row=1, col=2, title_text=f"{variable_dict['name']} ({variable_dict['unit']})")
fig.update_yaxes(row=1, col=2, title_text="z (m)")

# Global figure styling
fig.update_layout(
    title="BP8-QD Slip Evolution",
    template="plotly_white",
    height=800,
    showlegend=False,
    margin={"l": 60, "r": 20, "t": 60, "b": 50},
)

fig.show()