In [None]:
import celeri
from loguru import logger
import matplotlib.pyplot as plt
import numpy as np
import arviz as az

# Paths to the example model configurations; only the "japan" entry is used below.
config_files = {
    "japan": "../data/config/japan_config.json",
    "north_america": "../data/config/wna_config.json",
}

# Silence log messages emitted from the celeri package in this notebook.
logger.disable("celeri")

In [None]:
# Build two variants of the Japan model that differ only in the coupling
# constraints applied to every mesh except the last one.
def build_japan_model(coupling_bounds, last_mesh_elastic_bounds=(-100, 100)):
    """Build the Japan celeri model with the given coupling constraints.

    Every mesh except the last gets unconstrained elastic slip rates
    (``[None, None]``) and ``coupling_bounds`` as both its dip-slip and
    strike-slip coupling constraints. The last mesh instead gets
    ``last_mesh_elastic_bounds`` as both its dip-slip and strike-slip elastic
    constraints.

    Parameters
    ----------
    coupling_bounds : sequence of two numbers
        ``[lower, upper]`` coupling constraint for all but the last mesh.
    last_mesh_elastic_bounds : sequence of two numbers, default (-100, 100)
        ``[lower, upper]`` elastic constraint for the last mesh.

    Returns
    -------
    The model returned by ``celeri.build_model`` for the modified config.
    """
    config = celeri.Config.from_file(config_files["japan"])
    for mesh_config in config.mesh_params[:-1]:
        mesh_config.elastic_constraints_ds = [None, None]
        mesh_config.elastic_constraints_ss = [None, None]
        # Give each attribute its own list so no two constraints share one
        # mutable object.
        mesh_config.coupling_constraints_ds = list(coupling_bounds)
        mesh_config.coupling_constraints_ss = list(coupling_bounds)
    config.mesh_params[-1].elastic_constraints_ds = list(last_mesh_elastic_bounds)
    config.mesh_params[-1].elastic_constraints_ss = list(last_mesh_elastic_bounds)
    return celeri.build_model(config)


model_neg2_one = build_japan_model([-2, 1])  # coupling limits [-2, 1]
model_zero_one = build_japan_model([0, 1])  # coupling limits [0, 1]

In [None]:
%%time
# Sample the posterior of the [0, 1] coupling model with a single chain and
# persist the resulting estimation to a .zarr path on disk.
# NOTE(review): no random seed is passed, so draws differ between runs; check
# which seed keyword `solve_mcmc` / `sample_kwargs` accepts before adding one.
estimation_mcmc_zero_one = celeri.solve_mcmc(
    model_zero_one,
    sample_kwargs={"chains": 1},
)
estimation_mcmc_zero_one.to_disk("estimation_mcmc_zero_one.zarr")

In [None]:
%%time
# Sample the posterior of the [-2, 1] coupling model with a single chain and
# persist the resulting estimation to a .zarr path on disk.
# NOTE(review): no random seed is passed, so draws differ between runs; check
# which seed keyword `solve_mcmc` / `sample_kwargs` accepts before adding one.
estimation_mcmc_neg2_one = celeri.solve_mcmc(
    model_neg2_one,
    sample_kwargs={"chains": 1},
)
estimation_mcmc_neg2_one.to_disk("estimation_mcmc_neg2_one.zarr")

In [None]:
# Reload both saved MCMC estimations so the analysis below can run without
# re-sampling (loaded in the order neg2_one, zero_one).
estimation_mcmc_neg2_one, estimation_mcmc_zero_one = (
    celeri.Estimation.from_disk(store_path)
    for store_path in ("estimation_mcmc_neg2_one.zarr", "estimation_mcmc_zero_one.zarr")
)

In [None]:
# We can access individual MCMC draws as estimation objects
est1 = estimation_mcmc_neg2_one.mcmc_draw(0, 0)
est2 = estimation_mcmc_neg2_one.mcmc_draw(50, 0)
est3 = estimation_mcmc_neg2_one.mcmc_draw(150, 0)

In [None]:
# Effective sample size of the [-2, 1] model's trace, reduced to its minimum per variable.
ess_neg2_one = az.ess(estimation_mcmc_neg2_one.mcmc_trace)
ess_neg2_one.min()

In [None]:
# Effective sample size of the [0, 1] model's trace, reduced to its minimum per variable.
ess_zero_one = az.ess(estimation_mcmc_zero_one.mcmc_trace)
ess_zero_one.min()

In [None]:
# Trace plots of sigma**2 for both coupling models. Two labelling fixes:
# - The squared variable is renamed, because `sigma ** 2` keeps the name "sigma"
#   and would be mislabelled in the panels.
# - Each figure is titled with its model, so the two figures can be told apart.
for model_label, mcmc_estimation in {
    "neg2_one": estimation_mcmc_neg2_one,
    "zero_one": estimation_mcmc_zero_one,
}.items():
    sigma_squared = (mcmc_estimation.mcmc_trace.posterior.sigma ** 2).rename("sigma_squared")
    az.plot_trace(sigma_squared)
    # plot_trace draws into a new figure, which is then the current one.
    plt.gcf().suptitle(model_label)

In [None]:
# Slip-rate maps for a few individual posterior draws of the [-2, 1] model.
# Each figure is laid out as follows:
# - Keys are the draw indices the `est*` objects were created from; there is one figure per draw.
# - The top row is dip-slip and the bottom row is strike-slip.
# - Kinematic, coupling and elastic fields sit side by side for the first three meshes.
for draw_idx, estimation in {0: est1, 50: est2, 150: est3}.items():
    trace = estimation.mcmc_trace
    mesh_idxs = range(3)  # only the first three meshes are plotted
    meshes = [estimation.model.meshes[mesh_idx] for mesh_idx in mesh_idxs]

    fig = plt.figure(layout="constrained", figsize=(15, 7))
    # The two subfigures are stacked vertically, so their gap is `hspace`;
    # `wspace` has no effect with a single column.
    subfigs = fig.subfigures(2, 1, hspace=0.07)
    fig.suptitle(f"draw {draw_idx}")

    for kind, subfig in zip(["dip_slip", "strike_slip"], subfigs):
        subfig.suptitle(kind)
        # (field, vmin, vmax) for each panel, left to right.
        panels = [("kinematic", -100, 100), ("coupling", -2, 1), ("elastic", -100, 100)]
        for ax, (field, vmin, vmax) in zip(subfig.subplots(1, 3), panels):
            ax.set_title(field)
            # Concatenate the per-mesh values of this draw in mesh order, matching `meshes`.
            values = np.concatenate([
                trace.posterior[f"{field}_{mesh_idx}_{kind}"].sel(chain=0, draw=draw_idx).values
                for mesh_idx in mesh_idxs
            ])
            celeri.plot_meshes(
                meshes,
                values,
                ax=ax,
                vmin=vmin,
                vmax=vmax,
                cmap="coolwarm",
                center=0,
            )
            # Applied after plotting so it holds even if plot_meshes adjusts the axes.
            ax.set_aspect("equal")

In [None]:
# Posterior-mean slip-rate maps for each coupling model. Each figure is laid out as follows:
# - The top row is dip-slip and the bottom row is strike-slip.
# - Kinematic, coupling and elastic fields sit side by side for the first three meshes.
# - Colour limits are shared across models so the figures are directly comparable.
for name, estimation in {"zero_one": estimation_mcmc_zero_one, "neg2_one": estimation_mcmc_neg2_one}.items():
    trace = estimation.mcmc_trace
    mesh_idxs = range(3)  # only the first three meshes are plotted
    meshes = [estimation.model.meshes[mesh_idx] for mesh_idx in mesh_idxs]

    fig = plt.figure(layout="constrained", figsize=(15, 7))
    # The two subfigures are stacked vertically, so their gap is `hspace`;
    # `wspace` has no effect with a single column.
    subfigs = fig.subfigures(2, 1, hspace=0.07)
    fig.suptitle(name)

    for kind, subfig in zip(["dip_slip", "strike_slip"], subfigs):
        subfig.suptitle(kind)
        # (field, vmin, vmax) for each panel, left to right.
        panels = [("kinematic", -100, 100), ("coupling", -2, 1), ("elastic", -100, 100)]
        for ax, (field, vmin, vmax) in zip(subfig.subplots(1, 3), panels):
            ax.set_title(field)
            # Average over all chains and draws, then concatenate meshes in the order of `meshes`.
            values = np.concatenate([
                trace.posterior[f"{field}_{mesh_idx}_{kind}"].mean(["draw", "chain"]).values
                for mesh_idx in mesh_idxs
            ])
            celeri.plot_meshes(
                meshes,
                values,
                ax=ax,
                vmin=vmin,
                vmax=vmax,
                cmap="coolwarm",
                center=0,
            )
            # Applied after plotting so it holds even if plot_meshes adjusts the axes.
            ax.set_aspect("equal")

In [None]:
# Summary plot for the MCMC estimation of the [0, 1] coupling model.
celeri.plot_estimation_summary(
    estimation_mcmc_zero_one.model,
    estimation_mcmc_zero_one,
)

In [None]:
# Summary plot for the MCMC estimation of the [-2, 1] coupling model.
celeri.plot_estimation_summary(
    estimation_mcmc_neg2_one.model,
    estimation_mcmc_neg2_one,
)

In [None]:
# Solve the same two models with celeri's SQP solver (zero_one first, then
# neg2_one) for comparison with the MCMC results above.
estimation_sqp2_zero_one, estimation_sqp2_neg2_one = (
    celeri.solve_sqp2(sqp_model) for sqp_model in (model_zero_one, model_neg2_one)
)

In [None]:
# Summary plot for the SQP estimation of the [0, 1] coupling model.
celeri.plot_estimation_summary(
    estimation_sqp2_zero_one.model,
    estimation_sqp2_zero_one,
)

In [None]:
# Summary plot for the SQP estimation of the [-2, 1] coupling model.
celeri.plot_estimation_summary(
    estimation_sqp2_neg2_one.model,
    estimation_sqp2_neg2_one,
)