# In this notebook, I tried to reproduce the results of the following papers

Kubo et al., 2009 (Seismological and experimental constraints on metastable phase transformations
and rheology of the Mariana slab)

Note:
- some of the equations still don't show correctly. Also they seem to mess up the title when they are presented in the "OUTLINE".

In [None]:
import os
import sys
import re
import subprocess
from pathlib import Path
import pandas as pd
import numpy as np
from scipy.interpolate import UnivariateSpline
from matplotlib import pyplot as plt

# Repository root: three directory levels above the current working directory
root_path = os.path.join(Path().resolve().parent.parent.parent)

# Make the repository importable; this must run before the project imports below
if str(os.path.abspath(root_path)) not in sys.path:
    sys.path.insert(0, str(os.path.abspath(root_path)))

# Include this package
import hamageolib.utils.plot_helper as plot_helper
import hamageolib.research.haoyuan_2d_subduction.metastable as Meta

from hamageolib.utils.exception_handler import my_assert
from hamageolib.utils.handy_shortcuts_haoyuan import Mute

# Directory path of this notebook (the current working directory, which
# Jupyter normally sets to the notebook's folder)
base_dir = Path().resolve()

# For dumping results; created here if missing (os.mkdir does not create
# missing parent directories)
results_dir = os.path.join(root_path, "dtemp")
if not os.path.isdir(results_dir):
    os.mkdir(results_dir)

# For scripts (e.g. bash, paraview)
SCRIPT_DIR = os.path.join(root_path, "scripts")

# Utility functions

In [None]:
from vtk.util.numpy_support import vtk_to_numpy
from scipy.interpolate import interp1d
from scipy.spatial import cKDTree
    
# Number of seconds in one 365-day year
year = 365.0 * 24.0 * 3600.0


def round_values(values):
    """Return a list containing each element of *values* rounded with ``round``."""
    return [round(x) for x in values]

def idw_interpolation(points, values, query_points, k=5, power=1):
    """
    Perform inverse distance weighting (IDW) interpolation.

    Each query point receives the weighted mean of the values at its ``k``
    nearest known points, with weights ``1 / distance**power``. Distances
    are floored at 1e-12, so a query point that coincides with a known point
    effectively takes that point's value.

    Args:
        points (np.ndarray): Coordinates of known data points (N x D).
        values (np.ndarray): Values at the known data points (N,).
        query_points (np.ndarray): Coordinates of query points (M x D). A
            single point given as a 1-D array of length D is also accepted.
        k (int): Number of nearest neighbors to consider for interpolation.
            Values larger than N are clamped to N.
        power (int): Power parameter for the inverse distance weighting.

    Returns:
        np.ndarray: Interpolated values at the query points (M,).
    """
    # Ensure values are a NumPy array so fancy indexing works
    values = np.asarray(values)

    # Build a KDTree for fast nearest-neighbor lookup
    tree = cKDTree(points)

    # cKDTree pads missing neighbours with index N (out of range for
    # ``values``) and infinite distance when k > N, so clamp k.
    k_eff = min(k, tree.n)

    # Find the k nearest neighbors for each query point
    distances, indices = tree.query(query_points, k=k_eff)

    # For k == 1 (or a single query point) the query returns 1-D/scalar
    # results; force the (M, k) layout the weighting below expects.
    distances = np.reshape(distances, (-1, k_eff))
    indices = np.reshape(indices, (-1, k_eff))

    # Handle zero distances (avoid division by zero)
    distances = np.maximum(distances, 1e-12)

    # Compute weights as the inverse distance raised to the power, normalized per query point
    weights = 1.0 / distances**power
    weights /= weights.sum(axis=1, keepdims=True)

    # Interpolate the values
    return np.sum(weights * values[indices], axis=1)


def _resample_by_arc_length(xs, ys, spacing):
    """Resample one polyline (xs, ys) at uniform arc-length ``spacing``."""
    steps = np.sqrt(np.diff(xs)**2 + np.diff(ys)**2)

    # Drop repeated vertices: interp1d needs strictly increasing abscissae.
    moving = steps > 0
    keep = np.concatenate([[True], moving])
    xs, ys = xs[keep], ys[keep]
    if len(xs) < 2:
        return xs, ys

    # Cumulative arc length at each retained vertex
    cumulative_distances = np.concatenate([[0.0], np.cumsum(steps[moving])])

    # Cubic interpolation needs at least 4 points; fall back to linear otherwise.
    kind = 'cubic' if len(xs) >= 4 else 'linear'
    interp_x = interp1d(cumulative_distances, xs, kind=kind, bounds_error=False, fill_value="extrapolate")
    interp_y = interp1d(cumulative_distances, ys, kind=kind, bounds_error=False, fill_value="extrapolate")

    # Evenly spaced arc-length samples (the end point itself is excluded)
    samples = np.arange(0, cumulative_distances[-1], spacing)
    return interp_x(samples), interp_y(samples)


def extract_contour_coordinates(xv, yv, t_sub_grid, levels, spacing=None):
    """
    Extracts the x and y coordinates of contours from a meshed grid and field data.

    The contours are computed on a private matplotlib figure that is closed
    immediately, so the caller's current figure is neither drawn on nor closed.

    Args:
        xv (numpy.ndarray): 2D array of x-coordinates.
        yv (numpy.ndarray): 2D array of y-coordinates.
        t_sub_grid (numpy.ndarray): 2D array of field data on the grid.
        levels (list or array): Contour levels (values) to extract.
        spacing (float, optional): Desired spacing between contour points. When
            given, every contour path is resampled by arc length at this spacing
            (cubic interpolation for paths with >= 4 distinct vertices, linear
            otherwise). Defaults to None (raw contour vertices).

    Returns:
        dict: A dictionary with levels as keys and (x, y) coordinate arrays as values.
            Points from all paths of one level are concatenated.

    Raises:
        ValueError: If matplotlib returns a different number of contour levels
            than were requested.
    """
    fig, ax = plt.subplots()
    try:
        contours = ax.contour(xv, yv, t_sub_grid, levels=levels)
        # ``allsegs`` holds, per level, a list of (n, 2) vertex arrays. Unlike
        # ``ContourSet.collections`` (deprecated in matplotlib 3.8, removed in
        # 3.10) it is available across matplotlib versions.
        segments_per_level = contours.allsegs
    finally:
        plt.close(fig)  # Close the private figure to avoid displaying it

    if len(segments_per_level) != len(levels):
        raise ValueError(
            "Values of collections (%d) do not equal the provided levels(%d)"
            % (len(segments_per_level), len(levels)))

    contour_coordinates = {}
    for level, segments in zip(levels, segments_per_level):
        x_coords = []
        y_coords = []

        # Each segment is one connected contour path for this level
        for vertices in segments:
            xs, ys = vertices[:, 0], vertices[:, 1]

            if spacing is not None and len(xs) > 1:
                xs, ys = _resample_by_arc_length(xs, ys, spacing)

            x_coords.extend(xs)
            y_coords.extend(ys)

        contour_coordinates[level] = (np.array(x_coords), np.array(y_coords))

    return contour_coordinates

def offset_curve(X, Y, d):
    """
    Shift a curve sideways by a constant distance along its local normal.

    Tangents are estimated with ``np.gradient``; the normal at each point is
    the tangent (tx, ty) rotated by -90 degrees, i.e. (ty, -tx), normalized
    to unit length.

    Args:
        X (numpy.ndarray): X-coordinates of the curve points.
        Y (numpy.ndarray): Y-coordinates of the curve points.
        d (float): Offset distance along the unit normal (may be negative).

    Returns:
        tuple: (X_offset, Y_offset) coordinates of the shifted curve.
    """
    normal = np.column_stack((np.gradient(Y), -np.gradient(X)))
    unit_normal = normal / np.linalg.norm(normal, axis=1, keepdims=True)
    return X + d * unit_normal[:, 0], Y + d * unit_normal[:, 1]

def process_segments(slab_segments, n_spacing=10):
    """
    Processes segments to generate a continuous curve based on lengths and dip angles.

    Each segment is sampled at ``n_spacing`` points along its length, with the
    dip angle varying linearly from ``angle[0]`` to ``angle[1]`` (degrees).
    Depth and horizontal position are integrated along the arc with the
    trapezoidal rule. The first point of each segment repeats the last point
    of the previous segment.

    Args:
        slab_segments (iterable): Segment dictionaries with 'length' and 'angle'
            keys; 'angle' is a (start, end) pair in degrees. Any iterable
            (including a generator) is accepted.
        n_spacing (int): Number of points per segment for interpolation (>= 2).

    Returns:
        tuple: A tuple containing:
            - lengths (numpy.ndarray): Accumulated lengths at each point.
            - depths (numpy.ndarray): Accumulated depths at each point.
            - dip_angles (numpy.ndarray): Corresponding dip angles at each point, in radians.
            - Xs (numpy.ndarray): X coordinates of each point.

    Raises:
        ValueError: If ``n_spacing`` is smaller than 2 (a segment needs two
            end points, and the step length would divide by zero).
    """
    if n_spacing < 2:
        raise ValueError("n_spacing must be at least 2, got %r" % (n_spacing,))

    # Materialize once: the segments are both counted and iterated below,
    # which would silently exhaust a generator.
    slab_segments = list(slab_segments)
    total_points = n_spacing * len(slab_segments)

    # Preallocate arrays
    lengths = np.zeros(total_points)
    depths = np.zeros(total_points)
    dip_angles = np.zeros(total_points)
    Xs = np.zeros(total_points)

    accumulated_length = 0
    accumulated_depth = 0
    accumulated_x = 0

    for i, segment in enumerate(slab_segments):
        length = segment['length']
        angle_pair = segment['angle']

        # n_spacing points for angles; step length between consecutive points
        segment_angles = np.radians(np.linspace(angle_pair[0], angle_pair[1], n_spacing))
        step_length = length / (n_spacing - 1)

        # Trapezoidal rule: average the sin/cos of the dip at both step ends
        segment_depth_increments = step_length * (
            np.sin(segment_angles[:-1]) + np.sin(segment_angles[1:])
        ) / 2.0
        segment_x_increments = step_length * (
            np.cos(segment_angles[:-1]) + np.cos(segment_angles[1:])
        ) / 2.0

        # Accumulated lengths, depths and x positions within this segment
        segment_accumulated_lengths = accumulated_length + np.arange(n_spacing) * step_length
        segment_depths = accumulated_depth + np.concatenate([np.array([0]), np.cumsum(segment_depth_increments)])
        segment_xs = accumulated_x + np.concatenate([np.array([0]), np.cumsum(segment_x_increments)])

        # Write to preallocated arrays
        window = slice(i * n_spacing, (i + 1) * n_spacing)
        lengths[window] = segment_accumulated_lengths
        depths[window] = segment_depths
        dip_angles[window] = segment_angles
        Xs[window] = segment_xs

        # Next segment starts where this one ends
        accumulated_length = segment_accumulated_lengths[-1]
        accumulated_depth = segment_depths[-1]
        accumulated_x = segment_xs[-1]

    return lengths, depths, dip_angles, Xs


def distances_to_curve(Xs, Ys, x, y):
    """
    Computes the signed distance from each point in (x, y) to a curve defined by (Xs, Ys).

    Positive distance: below the curve (smaller y).
    Negative distance: above the curve (larger y).

    Args:
        Xs (numpy.ndarray): X-coordinates of the curve points (1D array).
        Ys (numpy.ndarray): Y-coordinates of the curve points (1D array).
        x (numpy.ndarray): X-coordinates of the points (1D array).
        y (numpy.ndarray): Y-coordinates of the points (1D array).

    Returns:
        numpy.ndarray: The signed distance from each point in (x, y) to the curve (1D array).
    """
    # Ensure x and y are numpy arrays
    x = np.asarray(x)
    y = np.asarray(y)

    # Reshape (x, y) to allow broadcasting against (Xs, Ys)
    x = x[:, np.newaxis]  # Shape (n_points, 1)
    y = y[:, np.newaxis]  # Shape (n_points, 1)

    # Compute squared distances for all combinations
    distances_squared = (Xs - x) ** 2 + (Ys - y) ** 2  # Shape (n_points, len(Xs))

    # Find the index of the closest point on the curve for each query point
    min_indices = np.argmin(distances_squared, axis=1)

    # Use min_indices to directly extract the minimum distances
    min_distances = np.sqrt(distances_squared[np.arange(len(x)), min_indices])

    # Determine if the point is above or below the curve
    closest_Ys = Ys[min_indices]
    signs = np.where(y.flatten() < closest_Ys, 1.0, -1.0)

    # Apply the sign to the distances
    signed_distances = signs * min_distances
    return signed_distances

# Equilibrium phase transition

In [None]:
# Equilibrium olivine -> wadsleyite phase transition near 410 km.
#   P        - Pa, equilibrium pressure
#   T        - K, equilibrium temperature
#   cl       - Pa/K, Clapeyron slope
#   dV_ol_wd - m^3 / mol, difference in volume between the phases; read by
#              growth_rate() and timescale_hosoya_06(), so it must be present
#              (2.4e-6, the volume difference listed for Hosoya 2006)
# Reference values not used in the code: dV = 0.052 (change in volume
# fraction), V_initial = 35.17e-6 m^3 / mol (olivine, estimation at 410 km).
PT410 = {"P": 13.5e9, "T": 1740.0, "cl": 2e6, "dV_ol_wd": 2.4e-6}

# Reproduce Literature

## Hosoya 2005 (The Kinetics)

### Summary: Growth Rate and Timescale Calculations Based on Hosoya 2006

This notebook contains Python code for calculating the growth rate and critical timescales of phase transformations based on the kinetic models described in *Hosoya et al., 2006*. The calculations rely on the Arrhenius law and consider the effects of pressure, temperature, and water content.

### Constants and Assumptions

- Universal gas constant: $R = 8.31446 \, \text{J mol}^{-1} \, \text{K}^{-1}$
- Activation enthalpy: $\Delta H = 274 \, \text{kJ/mol}$
- Activation volume: $V^* = 3.3 \times 10^{-6} \, \text{m}^3/\text{mol}$
- Volume difference: $\Delta V = 2.4 \times 10^{-6} \, \text{m}^3/\text{mol}$
- Water concentration scaling exponent: $n = 3.2$

### Functions Overview

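1. **`growth_rate_P1(P, T, Coh)`**:
   - Calculates the pressure-, temperature- and water-dependent part of Equation 2 from Hosoya (2006), i.e. the growth rate divided by $T[1-\exp(-\Delta G_r/RT)]$, for given pressure (`P`), temperature (`T`), and water concentration (`Coh`).
   - **Parameters**:
     - `P`: Pressure (Pa)
     - `T`: Temperature (K)
     - `Coh`: Water concentration (wt.ppm H2O)
   - **Returns**: This partial growth rate; `growth_rate` multiplies it by $T[1-\exp(-\Delta G_r/RT)]$ to obtain the growth rate in meters per second.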

2. **`growth_rate(P, P_eq, T, Coh)`**:
   - Computes the full growth rate of Equation 2 for scalar or array pressures; the rate is zero wherever the pressure does not exceed the equilibrium pressure (`P_eq`).
   - **Parameters**:
     - `P`: Pressure (float or ndarray)
     - `P_eq`: Equilibrium pressure (Pa)
     - `T`: Temperature (K)
     - `Coh`: Water concentration (wt.ppm H2O)
   - **Returns**: Growth rate as a float or array.

3. **`timescale_hosoya_06(P, P_eq, growth_rate, Coh)`**:
   - Despite its name, computes the critical temperature at which the growth rate (in its small-overpressure approximation) equals the prescribed `growth_rate`; Kubo et al. (2009) plot this as $T_{NE}$ and $T_{GN}$.
   - **Parameters**:
     - `P`: Pressure (Pa)
     - `P_eq`: Equilibrium pressure (Pa)
     - `growth_rate`: Growth rate (m/s)
     - `Coh`: Water concentration (wt.ppm H2O)
   - **Returns**: Critical temperature (K).

### Visualization

The notebook includes plots to visualize the growth rate variations:
- **Pressure vs. Growth Rate**: Explores the dependence of growth rate on pressure.
- **Temperature vs. Growth Rate**: Shows how growth rate changes with temperature.
- **OH Content vs. Growth Rate**: Investigates the effect of water concentration on growth rate.

In [None]:
from matplotlib import gridspec
from scipy.integrate import cumtrapz

def growth_rate_P1(P, T, Coh):
    """
    Evaluate the pressure-, temperature- and water-dependent part of Equation 2 in Hosoya 2006.

    Computes A * Coh**n * exp(-(dHa + P * V*) / (R * T)) with the Hosoya 2006
    constants; ``growth_rate`` multiplies this by T * (1 - exp(-dGr / (R * T)))
    to obtain the growth rate.

    Parameters:
    - P (float or ndarray): Pressure in Pascals.
    - T (float or ndarray): Temperature in Kelvin.
    - Coh (float or ndarray): Concentration of water in weight parts per million (wt.ppm H2O).

    Returns:
    - float or ndarray: The partial growth rate (NumPy broadcasting applies to arrays).
    """
    gas_constant = 8.31446          # J / mol*K
    prefactor = np.exp(-18.0)       # m s-1 wt.ppmH2O^(-3.2)
    water_exponent = 3.2
    activation_enthalpy = 274.0e3   # J / mol
    activation_volume = 3.3e-6      # m^3 / mol

    exponent = -(activation_enthalpy + P * activation_volume) / (gas_constant * T)
    return prefactor * Coh**water_exponent * np.exp(exponent)


def growth_rate_metastable(P, P_eq, T, Coh):
    """
    Calculate growth rate using Equation 2 from Hosoya 2006 for metastable conditions.

    The rate is evaluated with ``growth_rate`` only where the pressure exceeds
    the equilibrium pressure (positive overpressure); elsewhere it is zero.

    Parameters:
    - P (float or ndarray): Actual pressure (Pa).
    - P_eq (float or ndarray): Equilibrium pressure (Pa).
    - T (float or ndarray): Temperature (K). A scalar T selects the scalar
      path (P and P_eq must then be scalars too).
    - Coh (float): Water concentration (wt.ppm H2O).

    Returns:
    - float or ndarray: Growth rate for the given conditions.

    Raises:
    - ValueError: If array inputs P, P_eq and T do not share one shape.
    """
    # The result is never bound to the name ``growth_rate``: doing so made
    # that name local to this function, so calling growth_rate() failed
    # (UnboundLocalError in the scalar branch, "ndarray is not callable" in
    # the array branch).
    if np.ndim(T) == 0:
        # float() lets ints and NumPy scalars through the float check in growth_rate()
        if P > P_eq:
            return growth_rate(float(P), float(P_eq), float(T), Coh)
        return 0.0

    P = np.asarray(P)
    P_eq = np.asarray(P_eq)
    T = np.asarray(T)
    if P.shape != T.shape or P_eq.shape != T.shape:
        raise ValueError(
            "P %s, P_eq %s and T %s must have the same shape"
            % (P.shape, P_eq.shape, T.shape))

    rate = np.zeros(T.shape)
    mask = P > P_eq  # Check metastable condition
    rate[mask] = growth_rate(P[mask], P_eq[mask], T[mask], Coh)
    return rate


def growth_rate(P, P_eq, T, Coh):
    """
    Calculate the growth rate following Equation 2 in Hosoya 2006, considering
    pressure and temperature variations.

    rate = growth_rate_P1(P, T, Coh) * T * (1 - exp(-dGr / (R * T))),
    with dGr = PT410["dV_ol_wd"] * (P - P_eq); the rate is zero where P <= P_eq.

    Parameters:
    - P (float or np.ndarray): Pressure in Pascals. Any real scalar (int,
      float, NumPy scalar) selects the scalar path; anything else is treated
      as an array.
    - P_eq (float or np.ndarray): Equilibrium pressure in Pascals. For array P
      it may be a scalar or any shape broadcastable against P.
    - T (float or np.ndarray): Temperature in Kelvin (same rule as P_eq).
    - Coh (float or np.ndarray): Concentration of water in weight parts per
      million (wt.ppm H2O) (same rule as P_eq for array P).

    Returns:
    - float or np.ndarray: The growth rate for each pressure point; for array
      P the result has the broadcast shape of the inputs.
    """
    R = 8.31446  # J / mol*K, universal gas constant

    if np.ndim(P) == 0:
        if P > P_eq:
            dGr = PT410["dV_ol_wd"] * (P - P_eq)
            return growth_rate_P1(P, T, Coh) * T * (1 - np.exp(-dGr / (R * T)))
        return 0.0

    # Broadcast so a scalar P_eq / T / Coh can accompany an array of pressures
    P, P_eq, T, Coh = np.broadcast_arrays(P, P_eq, T, Coh)
    rates = np.zeros(P.shape)
    mask = P > P_eq
    Pm = P[mask]
    Tm = T[mask]
    dGr = PT410["dV_ol_wd"] * (Pm - P_eq[mask])
    rates[mask] = growth_rate_P1(Pm, Tm, Coh[mask]) * Tm * (1 - np.exp(-dGr / (R * Tm)))
    return rates


def timescale_hosoya_06(P, P_eq, growth_rate, Coh, dV_ol_wd=None):
    """
    Calculate the critical temperature for phase transition kinetics.

    Despite the function name, this returns a temperature, not a time. It
    solves

        growth_rate = A_dot * Coh**n * (P - P_eq) * exp(-(dHa + P * V*) / (R * T))

    for T, where A_dot = A * dV_ol_wd / R. This is the growth law of Hosoya 2006
    linearized for small dGr / RT (1 - exp(-x) ~ x, which cancels the T prefactor).
    Kubo et al. 2009 plot the result (minus 273.15) as T_NE and T_GN.

    Parameters:
    - P (float or np.ndarray): Pressure in Pascals.
    - P_eq (float or np.ndarray): Equilibrium pressure in Pascals.
    - growth_rate (float or np.ndarray): Prescribed growth rate in m/s.
    - Coh (float): Concentration of water in weight parts per million (wt.ppm H2O).
    - dV_ol_wd (float, optional): Volume difference between olivine and
      wadsleyite (m^3 / mol). Defaults to PT410["dV_ol_wd"].

    Returns:
    - float or np.ndarray: The critical temperature in Kelvin. Not physically
      meaningful where the logarithm's argument is <= 1.
    """
    R = 8.31446  # J / mol*K, universal gas constant

    if dV_ol_wd is None:
        dV_ol_wd = PT410["dV_ol_wd"]

    # Constants for calculation (Hosoya 2006)
    A = np.exp(-18.0)  # m s-1 wt.ppmH2O^(-3.2)
    A_dot = A * dV_ol_wd / R
    n = 3.2
    dHa = 274.0e3  # J / mol, activation enthalpy
    Vstar = 3.3e-6  # m^3 / mol, activation volume

    Tcr = (dHa + P * Vstar) / R / (np.log(A_dot * Coh**n * (P - P_eq) / growth_rate))

    return Tcr

def MO_Vfraction_classic(growth_rates, ts, da0):
    """
    Transformed volume fraction for grain-boundary nucleation.

    V(t) = 1 - exp(-2 * S * I(t)), where I(t) is the trapezoidal cumulative
    integral of ``growth_rates`` over ``ts`` (I = 0 at ts[0]). S = 3.35 / da0
    is the grain-boundary area per unit volume for equidimensional parent
    grains of size ``da0``.

    Parameters:
    - growth_rates (array-like): Growth rate at each time (m/s), 1-D.
    - ts (array-like): Times matching ``growth_rates`` (s), 1-D.
    - da0 (float): Parental (olivine) grain size (m).

    Returns:
    - np.ndarray: Transformed volume fraction at each time, same length as ``ts``.
    """
    # scipy.integrate.cumtrapz was removed in SciPy 1.14; cumulative_trapezoid
    # (SciPy >= 1.6) is its drop-in replacement.
    try:
        from scipy.integrate import cumulative_trapezoid
    except ImportError:  # SciPy < 1.6 only provides the old name
        from scipy.integrate import cumtrapz as cumulative_trapezoid

    growth_rates = np.asarray(growth_rates)

    # Cumulative integral of growth rates over time, starting at zero
    integral = cumulative_trapezoid(growth_rates, ts, initial=0.0)

    # Calculate transformed volume fraction
    S = 3.35 / da0  # Surface area per unit volume (1/m)
    return 1 - np.exp(-2.0 * S * integral)


# Toggle for the Hosoya (2006) growth-rate figures (the cells are skipped when False)
plot_hosoya_06 = False

if plot_hosoya_06:

    # Visualization of growth rate variations
    # All three panels plot ln(growth_rate_P1), i.e. the growth rate without
    # the T[1-exp(-dGr/RT)] factor, against pressure, 1000/T and water content.
    fig = plt.figure(tight_layout=True, figsize=(15, 5))
    gs = gridspec.GridSpec(1, 3)

    # Plot growth rate vs Pressure
    ax = fig.add_subplot(gs[0, 0])
    T = 900 + 273.15  # Temperature in Kelvin
    Ps = np.arange(13e9, 16e9, 0.1e9)  # Pressure range in Pascals
    Coh = 1000.0  # Concentration of water in wt.ppm H2O
    growth_rate_part = growth_rate_P1(Ps, T, Coh)
    ax.plot(Ps / 1e9, np.log(growth_rate_part))  # Pressure in GPa
    ax.grid()
    ax.set_xlim([13.0, 16.0])
    ax.set_ylim([-34.0, -22.0])
    ax.set_xlabel("Pressure (GPa)")
    ax.set_ylabel("ln(growth_rate/T[1-exp(-dGr/RT)])")

    # Plot growth rate vs Temperature (x axis is 1000/T, in units of 1000/K)
    ax = fig.add_subplot(gs[0, 1])
    T_invert = np.arange(0.7, 1.1, 0.01)  # 1000/T range
    Ts = 1000.0 / T_invert  # Temperature in Kelvin
    P = 15e9  # Pressure in Pascals
    Coh = 1000.0  # Concentration of water in wt.ppm H2O
    growth_rate_part = growth_rate_P1(P, Ts, Coh)
    ax.plot(T_invert, np.log(growth_rate_part))
    ax.grid()
    ax.set_xlim([0.7, 1.1])
    ax.set_ylim([-34.0, -22.0])
    ax.set_xlabel("1000/T (K)")

    # Plot growth rate vs OH content (log-spaced from 100 to 10^4 wt.ppm)
    ax = fig.add_subplot(gs[0, 2])
    T = 900 + 273.15  # Temperature in Kelvin
    P = 15e9  # Pressure in Pascals
    log10_Cohs = np.arange(2, 4, 0.05)  # Logarithmic OH content range
    Cohs = 10**log10_Cohs  # OH content in wt.ppm H2O
    growth_rate_part = growth_rate_P1(P, T, Cohs)
    ax.semilogx(Cohs, np.log(growth_rate_part))
    ax.grid()
    ax.set_xlim([10**2, 10**4])
    ax.set_ylim([-34.0, -22.0])
    ax.set_xlabel("OH content (wt. ppm H2O)")


We also plot the full growth rate for comparison.

In [None]:
if plot_hosoya_06:

    # Full growth rate of Eq. 2 (growth_rate) for comparison with the
    # growth_rate_P1 panels. Results are stored in ``rates``: assigning them
    # to ``growth_rate`` would rebind the growth_rate() function at notebook
    # top level, so the next call (and any later cell) would fail with
    # "'numpy.ndarray' object is not callable".
    # The rate is zero (ln -> -inf) wherever P <= P_eq.
    fig = plt.figure(tight_layout=True, figsize=(15, 5))
    gs = gridspec.GridSpec(1, 3)

    # Plot growth rate vs Pressure
    ax = fig.add_subplot(gs[0, 0])
    T = 900 + 273.15  # Temperature in Kelvin
    Ps = np.arange(13e9, 16e9, 0.1e9)  # Pressure range in Pascals
    Ts = np.full(Ps.shape, T)
    Ps_eq = np.full(Ps.shape, PT410["P"])
    Coh = 1000.0  # Concentration of water in wt.ppm H2O
    rates = growth_rate(Ps, Ps_eq, Ts, Coh)
    ax.plot(Ps / 1e9, np.log(rates))
    ax.grid()
    ax.set_xlabel("Pressure (GPa)")
    ax.set_ylabel("ln(Growth Rate (m/s))")
    ax.set_xlim([13.0, 16.0])

    # Plot growth rate vs Temperature
    ax = fig.add_subplot(gs[0, 1])
    T_invert = np.arange(0.7, 1.1, 0.01)  # 1000/T range
    Ts = 1000.0 / T_invert  # Temperature in Kelvin
    P = 15e9  # Pressure in Pascals
    Ps = np.full(Ts.shape, P)
    Ps_eq = np.full(Ps.shape, PT410["P"])
    Coh = 1000.0  # Concentration of water in wt.ppm H2O
    rates = growth_rate(Ps, Ps_eq, Ts, Coh)
    ax.plot(T_invert, np.log(rates))
    ax.grid()
    ax.set_xlim([0.7, 1.1])
    ax.set_xlabel("1000/T (K)")

    # Plot growth rate vs OH content
    ax = fig.add_subplot(gs[0, 2])
    T = 900 + 273.15  # Temperature in Kelvin
    P = 15e9  # Pressure in Pascals
    log10_Cohs = np.arange(2, 4, 0.05)  # Logarithmic OH content range
    Cohs = 10**log10_Cohs  # OH content in wt.ppm H2O
    Ts = np.full(Cohs.shape, T)
    Ps = np.full(Cohs.shape, P)
    Ps_eq = np.full(Cohs.shape, PT410["P"])
    rates = growth_rate(Ps, Ps_eq, Ts, Cohs)
    # ln(rate), consistent with the other two panels
    ax.semilogx(Cohs, np.log(rates))
    ax.grid()
    ax.set_xlim([10**2, 10**4])
    ax.set_xlabel("OH content (wt. ppm H2O)")

## Kubo et al., 2009

### Interpolation and Visualization of Temperature Profile

This section of the notebook focuses on loading, interpolating, and visualizing a temperature profile (their Fig. 2) based on depth from a CSV file.
They already include the latent heat release from the exothermic transition in the thermal model (Fig. 2).

### Key Steps

1. **Data Loading**:
   - Reads the temperature profile from a CSV file (`kubo_2009_T_center.csv`).
   - The file is located in the `data_set` directory within the specified `base_dir`.

2. **Data Extraction**:
   - Extracts depth (in kilometers) and temperature (in degrees Celsius) columns from the loaded dataset.

3. **Interpolation**:
   - Uses `scipy.interpolate.interp1d` to create an interpolation function for the temperature profile.
   - Generates interpolated values for depths at 1 km intervals.

4. **Visualization**:
   - Plots the interpolated temperature profile, with depth on the x-axis and temperature on the y-axis.
   - Includes axis labels and a grid for improved clarity.

### Assumptions and Requirements

- The CSV file must exist at the specified path, and its structure should include two columns: depth and temperature.
- The depth values should be sorted in ascending order for proper interpolation.

In [None]:
from scipy.interpolate import interp1d

# Toggle for all Kubo et al. (2009) reproduction cells below
reproduce_Kubo_2009 = False

if reproduce_Kubo_2009:

    # Load and plot the temperature profile (their Fig. 2) from a CSV file.
    file_T_path = os.path.join(base_dir, "data_set", "kubo_2009_T_center.csv")
    # Explicit check instead of ``assert``: it survives ``python -O`` and the
    # error names the missing file.
    if not os.path.isfile(file_T_path):
        raise FileNotFoundError("Temperature profile not found: %s" % file_T_path)
    data_T = pd.read_csv(file_T_path)  # Read data using pandas

    # Extract depth and temperature columns
    depths = data_T.iloc[:, 0]  # Depth in kilometers
    Ts = data_T.iloc[:, 1]  # Temperature in degrees Celsius

    # Interpolate temperature profile; T_interp is reused by later cells
    T_interp = interp1d(depths, Ts)  # Create an interpolating function

    # Generate interpolated values at 1 km intervals
    depths1 = np.arange(depths.iloc[0], depths.iloc[-1], 1.0)  # Interpolated depths (km)
    Ts1 = T_interp(depths1)  # Interpolated temperatures (C)

    # Plot the interpolated temperature profile
    fig, ax = plt.subplots()
    ax.plot(depths1, Ts1)  # Depth vs. temperature plot

    # Add plot labels and grid
    ax.set_xlabel("Depth (km)")  # Label for x-axis
    ax.set_ylabel("Temperature (C)")  # Label for y-axis
    ax.grid()  # Add grid for better visualization


### This section calculates and visualizes two types of temperature dependencies in their Fig. 4

1. Non-equilibrium temperature ($T_{NE}$) as a function of overpressure.
2. Grain size-dependent temperature ($T_{GN}$) as a function of grain size.

#### Equation for V

$V = 1 - \exp\left[-2S \int_{0}^{t} \dot{x}(\tau)\, d\tau\right]$

#### Key Steps

1. **Non-equilibrium Temperature ($T_{NE}$)**:
   - Computes $T_{NE}$ for 1000 ppm and 100 ppm water concentrations.
   - Uses an overpressure range ($\Delta P$) from $0.01 \, \text{GPa}$ to $1.0 \, \text{GPa}$.
   - Subduction velocity ($v_{sub}$) and timescale ($t_{sub}$) are used to estimate the growth rate.
   - Visualized as $T_{NE}$ versus overpressure in a log-log plot.

2. **Grain Size-dependent Temperature ($T_{GN}$)**:
   - Computes $T_{GN}$ for 1000 ppm and 100 ppm water concentrations.
   - Uses a grain size range ($d$) from $10^{-6} \, \text{m}$ to $10^{-1} \, \text{m}$.
   - Grain size-dependent growth rate ($v_{GN}$) is calculated using a fixed subduction timescale ($t_{sub}$).
   - Visualized as $T_{GN}$ versus grain size in a log-log plot.

#### Constants and Assumptions

- **Year Conversion**: $1 \, \text{year} = 365 \times 24 \times 3600 \, \text{s}$
- **Equilibrium Conditions**:
  - Equilibrium temperature: $T_{eq,410} = 1760 \, \text{K}$
  - Clapeyron slope: $c_{l,410} = 4 \times 10^6 \, \text{Pa/K}$
  - Equilibrium pressure: $P_{410} = 14 \, \text{GPa}$
- **Subduction Properties**:
  - Grain size: $d = 5 \times 10^{-3} \, \text{m}$
  - Subduction velocity: $v_{sub} = 0.1 \, \text{m/year}$
  - Subduction timescale: $t_{sub} = 10^5 \, \text{years}$

#### Additional notes
* Note that their ol-wd curves reference the Hosoya 2005 paper, while the post-spinel curve references the Kubo et al., 2002a, 2008 papers.
* S is the grain-boundary area of the parent phase per unit volume ($S = 3.35/d$), where d is the grain size of the parental olivine, assuming equidimensional grains (tetrakaidecahedra) [Cahn, 1956]. Note that this seems to be the only place where the parental grain size is needed.
* d is assumed to be 5 mm in their Figure 3.
* $t_{sub}$, the subduction timescale, is based on $v_{sub} t_{sub} = 10 \, \text{km}$; thus a velocity of 10 cm/yr yields a timescale of $10^5$ years.

In [None]:
if reproduce_Kubo_2009:


    # Create a figure with two subplots
    # NOTE: timescale_hosoya_06 returns a critical temperature in K (its
    # formula is (dHa + P V*) / R / ln(...)), so the ``ts_*`` arrays below hold
    # temperatures; 273.15 is subtracted to plot them in degrees C.
    fig = plt.figure(tight_layout=True, figsize=(5, 10))
    gs = gridspec.GridSpec(2, 1)

    # ---- Figure 4a: T_NE (non-equilibrium temperature) vs. overpressure ----
    ax = fig.add_subplot(gs[0, 0])

    P_eq = PT410["P"]  # Equilibrium pressure (Pa)
    lndP = np.arange(-2.0, 0.0, 0.05)  # Logarithmic overpressure range (log10 GPa)
    dP = 10**lndP * 1e9  # Overpressure range (Pa)
    P = P_eq + dP  # Actual pressure range (Pa)
    d = 5e-3  # Grain size (m); not used further in this cell
    v_sub = 0.1 / year  # Subduction velocity (m/s)
    t_sub = 1e5 * year  # Subduction timescale (s)
    # NOTE(review): the growth rate is set equal to the subduction velocity;
    # confirm this matches the T_NE definition in Kubo et al. (2009).
    growth_rate_NE = v_sub  # Growth rate for non-equilibrium

    # Calculate non-equilibrium temperature (T_NE, in K) for 1000 ppm water
    ts_NE = timescale_hosoya_06(P, P_eq, growth_rate_NE, 1000.0)
    ax.semilogx(dP / 1e9, ts_NE - 273.15, label="ol-wd, 1000 ppm")  # Plot curve for 1000 ppm

    # Calculate T_NE for 100 ppm water
    ts_NE_coh100 = timescale_hosoya_06(P, P_eq, growth_rate_NE, 100.0)
    ax.semilogx(dP / 1e9, ts_NE_coh100 - 273.15, label="ol-wd, 100 ppm")  # Plot curve for 100 ppm

    # Configure subplot
    ax.grid()
    ax.legend()
    ax.set_xlim([1e-2, 1.0])  # Overpressure range (GPa)
    ax.set_ylim([800, 2100.0])  # Temperature range (°C)
    ax.set_xlabel("Overpressure (GPa)")
    ax.set_ylabel("T_NE (C)")

    # ---- Figure 4b: T_GN (grain size-dependent temperature) vs. grain size ----
    ax = fig.add_subplot(gs[1, 0])

    logds = np.arange(-6, -1, 0.1)  # Logarithmic grain size range
    ds = 10**logds  # Grain size range (m)
    # Growth rate needed to grow across a grain of size d within t_sub
    growth_rate_GN = ds / t_sub  # Grain size-dependent growth rate (m/s)
    dP = 0.5e9  # Overpressure (Pa)
    P = P_eq + dP  # Pressure for grain size calculations (Pa)

    # Calculate grain size-dependent temperature (T_GN, in K) for 1000 ppm water
    ts_GN = timescale_hosoya_06(P, P_eq, growth_rate_GN, 1000.0)
    ax.semilogx(ds, ts_GN - 273.15, label="ol-wd, 1000 ppm")  # Plot curve for 1000 ppm

    # Calculate T_GN for 100 ppm water
    ts_GN_coh100 = timescale_hosoya_06(P, P_eq, growth_rate_GN, 100.0)
    ax.semilogx(ds, ts_GN_coh100 - 273.15, label="ol-wd, 100 ppm")  # Plot curve for 100 ppm

    # Configure subplot
    ax.grid()
    ax.legend()
    ax.set_xlim([1e-6, 1e-1])  # Grain size range (m)
    ax.set_ylim([400, 1700.0])  # Temperature range (°C)
    ax.set_xlabel("Parental grain size (m)")
    ax.set_ylabel("T_GN (C)")


### Figure 5: Nucleation and Growth Visualization

#### Key Steps

1. **Parameters**:
   - Temperature ($T$) and pressure ($P$) profiles are calculated for depths from 400 km to 700 km.
   - Grain sizes ($d$) are set to $5 \times 10^{-3} \, \text{m}$ for grain-boundary nucleation and $5 \times 10^{-5} \, \text{m}$ for intracrystalline nucleation.

2. **Growth Rates**:
   - Growth rates are computed using metastable conditions for varying water concentrations ($C_{OH}$).

3. **Transformed Volume Fractions**:
   - The cumulative transformed volume ($V$) is calculated as:
     $$
     V = 1 - \exp\left(-2 S \int_{0}^{t} \dot{x}(\tau)\, d\tau\right)
     $$
   - $S$ is the surface area per unit volume.

4. **Plots**:
   - **Grain-boundary nucleation**: $V$ is plotted against depth for water concentrations of 250, 500, and 750 wt.ppm.
   - **Intracrystalline nucleation**: $V$ is plotted against depth for water concentrations of 50, 150, and 250 wt.ppm.

#### Additional Notes
* They do not seem to include the positive feedback from the latent heat, as they note "Previous studies have revealed that the G-type olivine–ringwoodite transformation quickly completes when it proceeds to the critical volume fraction of 5–10 vol.% due to the positive feedback of the latent heat release (e.g., Rubie and Ross, 1994; Kirby et al., 1996; Mosenfelder et al., 2001)."
* The transformed volume by intracrystalline nucleation should always be bigger, since it assumes a smaller parental grain size and thus a larger nucleation area S.
* Note the slight differences between our curves and theirs (e.g. for the 250 wt.ppm H2O curve in Figure 5a at 700 km depth, theirs is around 0.85, while ours is around 0.95).

In [None]:
# figure 5: Grain-boundary and intracrystalline nucleation plots

if reproduce_Kubo_2009:

    # ``scipy.integrate.cumtrapz`` was deprecated in SciPy 1.12 and removed in
    # 1.14. Its replacement ``cumulative_trapezoid`` (SciPy >= 1.6) takes the
    # same (y, x) arguments. Binding it to the old name keeps the later
    # Figure 7 cell, which calls ``cumtrapz``, working.
    try:
        from scipy.integrate import cumulative_trapezoid as cumtrapz
    except ImportError:  # SciPy < 1.6 only provides the old name
        from scipy.integrate import cumtrapz
    from matplotlib import gridspec

    # Retrieve the default color cycle for consistent plot coloring
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Reference parameters from Hosoya 2005 (depth0, v_sub and dPdh are
    # also used by the Figure 7 cell)
    T660 = 873.0  # Temperature at 660 km depth (K)
    Tgrad = 0.6 / 1e3  # Temperature gradient (K/m)
    dPdh = 30e6 / 1e3  # Pressure gradient (Pa/m)
    v_sub = 0.095 / year  # Subduction velocity (m/s)
    depth0 = 400e3  # Initial depth for calculations (m)

    # Define depth, temperature, and pressure profiles
    depths = np.arange(depth0, 700e3, 1e3)  # Depths from 400 to 700 km (m)
    ts = (depths - depth0) / v_sub  # Time required for subduction to each depth (s)
    Ts = T_interp(depths / 1e3) + 273.15  # Temperature interpolated and converted to K
    Ps = (depths - 410e3) * dPdh + PT410["P"]  # Pressure profile (Pa)

    # Equilibrium pressure and corresponding depths
    Ps_eq = (Ts - PT410["T"]) * PT410["cl"] + PT410["P"]  # Equilibrium pressure (Pa)
    depths_eq = 410e3 + (Ts - PT410["T"]) * PT410["cl"] / dPdh  # Equilibrium depth (m)


    # Create a figure for nucleation plots
    fig = plt.figure(tight_layout=True, figsize=(5, 10))
    gs = gridspec.GridSpec(2, 1)

    # ---- Plot 1: Grain-boundary nucleation ----
    ax = fig.add_subplot(gs[0, 0])

    d_ol = 5e-3  # Grain size for olivine (m)
    S = 3.35 / d_ol  # Surface area per unit volume (1/m)

    # Calculate and plot for different water concentrations
    for Coh, color, label in [(250.0, default_colors[2], "250 wt.ppm H2O"),
                            (500.0, default_colors[1], "500 wt.ppm H2O"),
                            (750.0, default_colors[0], "750 wt.ppm H2O")]:
        growth_rates = growth_rate_metastable(Ps, Ps_eq, Ts, Coh)
        integral = np.zeros(depths.shape)  # Initialize integral
        integral[1:] = cumtrapz(growth_rates, ts)  # Cumulative integral
        V = 1 - np.exp(-2.0 * S * integral)  # Transformed volume fraction
        ax.plot(depths / 1e3, V, label=label, color=color)

    # Configure plot
    ax.grid()
    ax.legend()
    ax.set_xlabel("Depth (km)")
    ax.set_ylabel("Transformed Volume Fraction")
    ax.set_title("Grain-boundary nucleation")

    # ---- Plot 2: Intracrystalline nucleation ----
    ax = fig.add_subplot(gs[1, 0])

    d_ol = 5e-5  # Grain size for olivine (m)
    S = 3.35 / d_ol  # Surface area per unit volume (1/m)

    # Calculate and plot for different water concentrations
    for Coh, color, label in [(50.0, default_colors[4], "50 wt.ppm H2O"),
                            (150.0, default_colors[3], "150 wt.ppm H2O"),
                            (250.0, default_colors[2], "250 wt.ppm H2O")]:
        growth_rates = growth_rate_metastable(Ps, Ps_eq, Ts, Coh)
        integral = np.zeros(depths.shape)  # Initialize integral
        integral[1:] = cumtrapz(growth_rates, ts)  # Cumulative integral
        V = 1 - np.exp(-2.0 * S * integral)  # Transformed volume fraction
        ax.plot(depths / 1e3, V, label=label, color=color)

    # Configure plot
    ax.grid()
    ax.legend()
    ax.set_xlabel("Depth (km)")
    ax.set_ylabel("Transformed Volume Fraction")
    ax.set_title("Intracrystalline nucleation")

    plt.show()


### Figure 7: Grain size evolution and viscosity calculation

#### Key Steps

1. **Viscosity Calculation**:
   - The flow law from *Kubo et al., 2008* is used to compute viscosity:
     $$
     \eta = \eta_0 \exp\left(\frac{T_0 - T}{a} + \frac{z}{b} - \left(\frac{z}{c}\right)^2\right)
     $$
     - Parameters:
       - $T$: Temperature (K).
       - $z$: Depth (m).
       - $\eta_0$, reference viscosity.
       - $T_0 = 1873.15 \, \text{K}$, reference temperature.
       - $a = 131.3 \, \text{K}$, $b = 150 \, \text{km}$, $c = 1086 \, \text{km}$.

2. **Grain Size Profiles**:
   - Initial grain sizes:
     - Olivine: $d_{a0} = 5 \, \text{mm}$.
     - Wadsleyite: $d_{b0} = 0.1 \, \text{mm}$.
   - Grain size evolution is based on the transformed volume fraction $V$:
     $$
     V = 1 - \exp(-2.0 \cdot S \cdot \text{integral})
     $$
     - $S = \frac{3.35}{d_{a0}}$: Surface area per unit volume (1/m).
     - The integral accumulates growth rates over time.

3. **Switching Conditions**:
   - Grain size switches from olivine to wadsleyite when $V > 0.5$.

#### Constants and Assumptions

- **Depth Range**: 400 km to 700 km.
- **Water Concentration**: $C_{OH} = 500 \, \text{wt.ppm}$.
- **Subduction Velocity**: $v_{\text{sub}} = 0.095 \, \text{m/yr}$.

#### Additional Notes

* Current state: I failed to reproduce the grain size evolution. I assumed the grain size to be the integration of growth rate, this might cause the problem. It might also be the case there is something wrong in the grain growth mechanism.
* For the viscosity, they used the Peierls creep in Karato et al., 2001 and Diffusion creep in Frost and Ashby, I didn't go further into these references.

In [None]:
# Figure 7
def flow_law_Kubo_2008(T, z, eta0=1e19, T0=1600.0 + 273.15, a=131.3, b=150.0e3, c=1086e3):
    """
    Compute viscosity using the empirical flow law from Kubo et al. 2008:

        eta = eta0 * exp((T0 - T) / a + z / b - (z / c)**2)

    Parameters:
    - T (float or ndarray): Temperature (K).
    - z (float or ndarray): Depth (m).
    - eta0 (float): Reference viscosity (Pa·s), i.e. the viscosity at T = T0 and z = 0.
      The previous hard-coded value, 1e-19 Pa·s, is physically implausible for mantle
      rocks and is treated here as a sign typo for 1e19 Pa·s.
      NOTE(review): confirm against Kubo et al. (2008).
    - T0 (float): Reference temperature (K); default 1600 C.
    - a (float): Temperature scaling factor (K).
    - b (float): Length scale of the linear depth term (m).
    - c (float): Length scale of the quadratic depth term (m).

    Returns:
    - float or ndarray: Viscosity (Pa·s); broadcasts over array inputs.
    """
    # Colder material (T < T0) and greater depth (up to the quadratic correction)
    # both raise the viscosity exponentially.
    return eta0 * np.exp((T0 - T) / a + z / b - (z / c) ** 2.0)

if reproduce_Kubo_2009:
    # Imported locally so this cell does not depend on execution order: ``gridspec`` is
    # otherwise only imported in a later cell, and ``cumulative_trapezoid`` replaces
    # ``cumtrapz``, which was removed in SciPy 1.14.
    from matplotlib import gridspec
    from scipy.integrate import cumulative_trapezoid

    # Initial parameters
    da0 = 5e-3  # Initial grain size for olivine (m)
    # NOTE(review): the notes above quote d_b0 = 0.1 mm, but 1e-7 m is 0.1 micron -- confirm.
    db0 = 1e-7  # Initial grain size for wadsleyite (m)
    Coh = 500  # Water concentration (wt.ppm H2O)

    # Depth, time, temperature and pressure profiles along the slab.
    # depth0, v_sub, T_interp, dPdh and PT410 come from earlier cells.
    depths = np.arange(depth0, 700e3, 1e3)  # Depths from depth0 to 700 km (m)
    ts = (depths - depth0) / v_sub  # Subduction time to each depth (s)
    Ts = T_interp(depths / 1e3) + 273.15  # Temperature interpolated and converted to K
    Ps = (depths - 410e3) * dPdh + PT410["P"]  # Pressure profile (Pa)

    # Equilibrium pressure of the phase boundary at each temperature (Pa)
    Ps_eq = (Ts - PT410["T"]) * PT410["cl"] + PT410["P"]

    # Start with the olivine grain size everywhere (m)
    ds = np.full(depths.size, da0)

    # Growth rates under metastable conditions
    growth_rates = growth_rate_metastable(Ps, Ps_eq, Ts, Coh)

    # Cumulative time integral of the growth rate; initial=0.0 places a zero at the first
    # depth so the result has the same length as ``depths``.
    integral = cumulative_trapezoid(growth_rates, ts, initial=0.0)

    # Transformed volume fraction
    S = 3.35 / da0  # Surface area per unit volume (1/m)
    V = 1 - np.exp(-2.0 * S * integral)

    # Wadsleyite grain size evolution (m)
    db = db0 + integral

    # Where more than half of the volume has transformed, switch to the wadsleyite grain size
    mask = V > 0.5
    ds[mask] = db[mask]

    # Figure for grain size plots (only the first of the two GridSpec rows is used)
    fig = plt.figure(tight_layout=True, figsize=(5, 10))
    gs = gridspec.GridSpec(2, 1)

    # ---- Plot 1: Grain Size vs Depth ----
    ax = fig.add_subplot(gs[0, 0])
    # ``ds`` is the effective grain size: olivine until V > 0.5, wadsleyite afterwards
    ax.semilogy(depths / 1e3, ds / 1e-6, label="grain size (effective)")  # micron
    ax.semilogy(depths / 1e3, db / 1e-6, label="grain size (wadsleyite)")  # micron

    # Configure plot
    ax.grid()
    ax.legend()
    ax.set_xlabel("Depth (km)")
    ax.set_ylabel("Grain size (micron)")

    plt.show()


## Tetzlaff & Schmeling 09

### Import their temperature profile in figure 1

Here I only plot their curves with no latent heat effects

### Check the value of their growth rate

Note: using their documented values results in a growth rate that is too low (compared with the same plot in the Hosoya paper). This in turn results in a zero volume fraction of wadsleyite (wd).

In [None]:
from matplotlib import gridspec

def Tprofile_TS09_fig1_warm(depths):
    """
    Warm-slab temperature profile taken from figure 1 of Tetzlaff & Schmeling (2009).

    The profile is the straight line through two (depth, temperature) anchor points
    read off the figure; depths outside the anchors are extrapolated linearly.

    Parameters:
    - depths (float or ndarray): Depth (m).

    Returns:
    - float or ndarray: Temperature (K).
    """
    # Anchor points: (depth in m, temperature in K)
    z_top, T_top = 370.28496710020283e3, 684.1150723737485 + 273.15
    z_bot, T_bot = 698.9440717961519e3, 912.82119751198 + 273.15

    # Lagrange weights of the linear interpolant (they sum to one)
    w_bot = (depths - z_top) / (z_bot - z_top)
    w_top = (depths - z_bot) / (z_top - z_bot)
    return w_bot * T_bot + w_top * T_top

def Tprofile_TS09_fig1_cold(depths):
    """
    Cold-slab temperature profile taken from figure 1 of Tetzlaff & Schmeling (2009).

    The profile is the straight line through two (depth, temperature) anchor points
    read off the figure; depths outside the anchors are extrapolated linearly.

    Parameters:
    - depths (float or ndarray): Depth (m).

    Returns:
    - float or ndarray: Temperature (K).
    """
    # Anchor points: (depth in m, temperature in K)
    z_top, T_top = 352.5238657769547e3, 472.0332858291416 + 273.15
    z_bot, T_bot = 694.6273339699476e3, 711.0368878272202 + 273.15

    # Lagrange weights of the linear interpolant (they sum to one)
    w_bot = (depths - z_top) / (z_bot - z_top)
    w_top = (depths - z_bot) / (z_top - z_bot)
    return w_bot * T_bot + w_top * T_top

def growth_rate_tetzlaff_schmeling_09(Ts, Ps, depths, depth_eq, k0=2005.0, dHa=350e3, Va=1.3e-5, Lz=0.5393):
    """
    Interface growth rate of wadsleyite following Tetzlaff & Schmeling (2009):

        Y = k0 * T * exp(-(dHa + P * Va) / (R T)) * [1 - exp(-dGr / (R T))]

    with the reaction free-energy change taken as dGr = Lz * (depth - depth_eq).

    Parameters:
    - Ts (float or ndarray): Temperature (K).
    - Ps (float or ndarray): Pressure (Pa).
    - depths (float or ndarray): Depth (m).
    - depth_eq (float or ndarray): Equilibrium depth of the phase boundary (m).
    - k0 (float): Pre-exponential factor. 20**5 has also been tried (see the note
      above about documented values giving rates that are too low).
    - dHa (float): Activation enthalpy (J/mol).
    - Va (float): Activation volume (m^3/mol).
    - Lz (float): Free-energy change per unit depth beyond equilibrium
      (presumably J/mol per m -- confirm against the paper).

    Returns:
    - float or ndarray: Growth rate. It is zero at depth_eq and negative above it
      (depths < depth_eq), where the driving force changes sign.
    """
    R = 8.31446  # J / (mol K), universal gas constant
    dGr = Lz * (depths - depth_eq)  # consistent with their enthalpy calculation
    arrhenius = np.exp(-(dHa + Ps * Va) / R / Ts)
    # 1 - exp(-x) written as -expm1(-x): mathematically identical, but it keeps full
    # precision when x -> 0, i.e. close to the equilibrium depth.
    driving_force = -np.expm1(-dGr / R / Ts)
    return k0 * Ts * arrhenius * driving_force


    
def TS09_check(T=900 + 273.15):
    """
    Plot the pressure dependence of the Tetzlaff & Schmeling (2009) growth rate.

    The growth rate from ``growth_rate_tetzlaff_schmeling_09`` is divided by its
    T * [1 - exp(-dGr / RT)] factor. The curve therefore shows the log of the remaining
    pre-exponential/Arrhenius part versus pressure, for comparison with the equivalent
    plot in the Hosoya paper.

    Parameters:
    - T (float): Temperature (K) held constant along the curve; default 900 C.

    Uses the module-level ``PT410`` dictionary (keys "P", "T", "cl") for the 410-km
    phase boundary.
    """
    R = 8.31446  # J / mol*K, universal gas constant
    dPdh = 30e6 / 1e3  # Pa/m
    Lz = 0.5393  # must match the value used in growth_rate_tetzlaff_schmeling_09

    Ps = np.arange(13e9, 16e9, 0.1e9)  # Pressure range in Pascals
    depths = (Ps - PT410["P"]) / dPdh + 410e3
    depths_eq = 410e3 + (T - PT410["T"]) * PT410["cl"] / dPdh

    dGr = Lz * (depths - depths_eq)  # consistent with their enthalpy calculation
    growth_rate = growth_rate_tetzlaff_schmeling_09(T, Ps, depths, depths_eq)

    # Single panel: growth rate vs pressure
    fig, ax = plt.subplots(tight_layout=True, figsize=(5, 5))
    ax.plot(Ps / 1e9, np.log(growth_rate / (T * (1 - np.exp(-dGr / R / T)))))  # Pressure in GPa
    ax.grid()
    ax.set_xlim([13.0, 16.0])
    ax.set_xlabel("Pressure (GPa)")
    ax.set_ylabel("ln(growth_rate/T[1-exp(-dGr/RT)])")


TS09_check()

In [None]:
q_depths = np.arange(300e3, 700e3, 1e3)

# Slab temperature profiles (K) from figure 1 of Tetzlaff & Schmeling (2009)
q_Ts_warm = Tprofile_TS09_fig1_warm(q_depths)
q_Ts_cold = Tprofile_TS09_fig1_cold(q_depths)

fig = plt.figure(tight_layout=True, figsize=(5, 10))
gs = gridspec.GridSpec(2, 1)
ax = fig.add_subplot(gs[0, 0])

ax.plot(q_depths/1e3, q_Ts_cold - 273.15, label="cold")
ax.plot(q_depths/1e3, q_Ts_warm - 273.15, label="warm")

ax.set_xlim([350, 700])
ax.set_ylim([450, 1000])

ax.set_xlabel("Depth (km)")
ax.set_ylabel("Temperature (C)")
ax.legend()

ax.grid()


# calculate MO kinetics
# Parameters for equilibrium at the 410 km depth phase boundary
dPdh = 30e6/1e3 # Pa/m
d_ol = 5e-3 # m, olivine grain size

# Rate of depth increase (5 cm/yr) and the time at which each depth is reached,
# assuming a constant rate from the surface
v_h = 0.05/year # m / s
q_ts = q_depths / v_h  # s

q_Ps = (q_depths - 410e3)*dPdh + PT410["P"]  # Pa
q_depths_eq_cold = 410e3 + (q_Ts_cold - PT410["T"]) * PT410["cl"] / dPdh  # Equilibrium depth (m)

q_growth_rate_cold = growth_rate_tetzlaff_schmeling_09(q_Ts_cold, q_Ps, q_depths, q_depths_eq_cold)
# Above the equilibrium depth the driving-force term 1 - exp(-dGr/RT) is negative, so the
# raw rate is negative there and would wrongly reduce the accumulated transformation.
# Wadsleyite only grows where it is stable, so zero the rate above the equilibrium depth.
# This mirrors the P > P_eq masking used for the Yoshioka et al. (2015) rates below.
q_growth_rate_cold = np.where(q_depths > q_depths_eq_cold, q_growth_rate_cold, 0.0)
V_cold = MO_Vfraction_classic(q_growth_rate_cold, q_ts, d_ol)

# Plot the wd contents
ax1 = fig.add_subplot(gs[1, 0])

ax1.plot(q_depths/1e3, V_cold)
ax1.set_xlim([350, 700])
ax1.set_ylim([0.0, 1.0])
ax1.set_xlabel("Depth (km)")
ax1.set_ylabel("Wadsleyite volume fraction")

ax1.grid()


plt.show()

# Kinetics from Däßler et al., 1996, as used in Yoshioka et al., 2015

## Nucleation Theory

(Supplementary Material: Put the final equation and the formula of $\Delta G_{c}$ here)

The nucleation rate is expressed as:

$$
I = K_0 T \exp\left(-\frac{\Delta H_a + P V^*}{RT}\right) \exp\left(-\frac{\Delta G^*}{kT} \right)
$$

$$
K_0 = \frac{N k}{h} = \frac{1.75 \times 10^{28} \times 1.38 \times 10^{-23}}{6.626 \times 10^{-34}} \approx 3.65 \times 10^{38}~\mathrm{K^{-1} \cdot s^{-1} \cdot m^{-3}}
$$

The critical Gibbs free energy $\Delta G^*$ depends on whether nucleation is assumed to be homogeneous or heterogeneous. For homogeneous nucleation, the critical change in total Gibbs free energy ($\Delta G_{c}$) is:

$$
\Delta G_{hom}^* = \dfrac{16\pi \gamma^3 V_m^2}{3\left(\Delta G_d + \epsilon\right)^2}
$$

$$
\Delta G_{het}^* = f_s \Delta G_{hom}^*
$$

where $\gamma$ is the surface energy associated with forming the spinel structure from olivine, $V_m$ is the molar volume of spinel, and $f_s$ is a shape factor that reduces the critical free energy change below its homogeneous value when heterogeneous nucleation effectively wets the grain boundary and lowers the total surface energy.

In [None]:
# In this notebook, we import this function from the metastable.py script
# Add the origina function here

# def nucleation_rate(P, T, P_eq):
#     """
#     Compute the nucleation rate using Equation (10) from Yoshioka et al. (2015).
    
#     Parameters:
#     - T (float): Temperature in Kelvin.
#     - P (float): Pressure in Pa
#     - delta_G_v (float): Free energy change per volume in J/m^3.

#     Constants
#     - gamma (float): Surface energy in J/m^2 (default: 0.0506).
#     - K0 (float): Pre-exponential factor in s^-1 m^-2 K^-1 (default: 3.54e38).
#     - Q_a (float): Activation energy for growth in J/mol (default: 355e3).
#     - k (float): Boltzmann constant in J/K (default: 1.38e-23).
#     - R (float): Universal gas constant in J/(mol*K) (default: 8.314).
#     - dV_ol_wd (float): difference in mole volume between phase.
#     - V_initial (float): for olivine, estimation at 410 km, mole volume
    
#     Returns:
#     - I (float): Nucleation rate in s^-1 m^-2.
#     """
#     gamma=0.0506; K0=3.54e38; dH_a=344e3; V_star=4e-6; k=1.38e-23; R=8.314
#     dV_ol_wd = 2.4e-6; V_initial = 35.17e-6

#     if type(P) in [float, np.float64]:
#         assert(P >= P_eq)
#     elif type(P) == np.ndarray:
#         assert(np.min(P - P_eq) >= 0.0)
#     else:
#         raise TypeError("P must be float or ndarray")

#     delta_G_v = dV_ol_wd / V_initial * (P - P_eq)

#     # print("P_eq = %.2f GPa, dGr_vs = %.4e" % (P_eq/1e9, delta_G_v)) # debug

#     # Compute the homogeneous nucleation activation energy
#     delta_G_hom = (16 * np.pi * gamma**3) / (3 * (delta_G_v)**2)
    
#     # Compute the nucleation rate
#     Q_a = dH_a + P * V_star 
#     I = K0 * T * np.exp(-delta_G_hom / (k * T)) * np.exp(-Q_a / (R * T))
    
#     return I

## Site saturation

(Supplementary Material: Put the equation of site saturation and the nondimensional number for time here)

The time of site saturation is defined as:
$$t_s =  \left(I_s(P, T) Y(P, T)^2 \right)^{-1/3} = \left(\frac{I_v(P, T) Y(P, T)^2}{S_0} \right)^{-1/3}$$
Before site saturation ($t < t_s$), both nucleation and grain growth contribute to the transformation kinetics. After site saturation ($t > t_s$), nucleation becomes ineffective.

Additionally, time is nondimensionalized as:
$$t = \frac{L^2}{\kappa} \tau$$
where $L$ is a characteristic length scale (taken as 100 km) and $\kappa$ is the thermal diffusivity.

In [None]:
# def calculate_sigma_s(I_PT, Y_PT, d_0, **kwargs):
#     """
#     Calculate the dimensionless time (sigma_s) for the phase transformation process.

#     Parameters:
#     - I_PT (float): Nucleation rate as a function of pressure and temperature (s^-1 m^-3).
#     - Y_PT (float): Growth rate as a function of pressure and temperature (m/s).
#     - d_0 (float): Grain size of olivine (m).

#     Returns:
#     - sigma_s (float): Dimensionless time for site saturation.
#     """
#     kappa = kwargs.get("kappa", 1e-6) # Thermal diffusivity (m^2/s).
#     D = kwargs.get("D", 100e3) # slab thickness
#     # Compute the dimensionless time
#     sigma_s = (kappa / D**2) * ((I_PT * Y_PT**2 * d_0) / 6.7)**(-1/3)
#     return sigma_s


## Equation (19): Avrami Number Calculation

We follow by defining the Avrami number as

$$
Av = \left(\frac{D^2}{\kappa}\right)^4 \cdot I_{max}(P, T) Y_{max}^3(P, T)
$$

In [None]:
# def calculate_avrami_number(I_max, Y_max, **kwargs):
#     """
#     Calculate the Avrami number (Av) using the corrected Equation (19).
    
#     Parameters:
#     - I_max (float): Maximum nucleation rate in s^-1 m^-2.
#     - Y_max (float): Maximum growth rate in m/s.
    
#     Returns:
#     - Av (float): Avrami number (dimensionless).
#     """
#     kappa = kwargs.get("kappa", 1e-6) # Thermal diffusivity (m^2/s).
#     D = kwargs.get("D", 100e3) # slab thickness
#     # Compute the Avrami number
#     Av = (D**2 / kappa)**4 * I_max * Y_max**3
#     return Av

## Kinetic equations

(Supplementary Material: put the extended volume fraction and the kinetic equations here; screenshot a non-dimensional equation.)

In this system, new grains nucleate at the surface of existing grains and grow kinetically.
This leads to an increase in the so-called extended volume fraction $\tilde{V}$, which assumes no overlap between grains.
The true volume fraction $V$ is related to $\tilde{V}$ through the transformation (Avrami, 1941):
$$
V = 1 - \exp\left(-\tilde{V}\right)
$$
$$
\frac{d\tilde{V}}{dt} = 4 S Y(t), \qquad t < t_s
$$
$$
\frac{dS}{dt} = \pi D Y(t)
$$
$$
\frac{dD}{dt} = 2 N Y(t)
$$
$$
\frac{dN}{dt} = I_v
$$

After site saturation ($t > t_s$), nucleation becomes ineffective, and the extended volume fraction evolves according to:
$$
\frac{d\tilde{V}}{dt} = S_0 Y(P, T)
$$

where $S$ is the total grain area per unit volume, $D$ is the total grain size per unit volume, and $N$ is the number of new grains per unit volume.

![](./figure/nondimensional_1.png)

Where:
- $X_3(s)$, $X_2(s)$, $X_1(s)$, $X_0(s)$: the total grain volume, total grain area, total grain diameter, and number of grains, respectively.
- $Av$: Avrami number, defined as $Av = \left(\frac{D^2}{\kappa}\right)^4 I_{max}(P, T) Y_{max}^3(P, T)$.
- $Y'(s)$: Dimensionless growth rate.
- $I'(s)$: Dimensionless nucleation rate.

Note
- I use the name "X_array" for the original array of [$\tilde{V}$, S, D, N], and the name "X_array_nd" for the nondimensionalized array. This naming is used consistently in all the code blocks.

## Plot the summary of nucleation rate, growth rate and saturation time

In [None]:
plot_analysis = False

if plot_analysis:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator

    # parameters for kinetics
    nucleation_type = 1 # 0 - volumetric; 1 - interface
    # m, olivine grain size; the 6.7 / d0 factor below converts interface nucleation
    # rates to volumetric ones
    d0 = 1e-2

    # initiate the kinetics class
    _constants, _ = Meta.get_kinetic_constants(nucleation_type)
    pTKinetics = Meta.PTKinetics(_constants)

    # directory to save results (results_dir is already rooted at root_path)
    o_dir = os.path.join(results_dir, "plot_analysis")
    os.makedirs(o_dir, exist_ok=True)

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Example usage
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 2.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (11.0, 16.0)
    x_lim3 = (0.4, 1.1)
    x_lim5 = (2.0, 4.0)
    x_tick_interval = 1.0   # tick interval along x
    x_tick_interval3 = 0.1   # tick interval along x
    x_tick_interval5 = 0.5
    y_lim = (-20.0, 30.0)
    y_lim2 = (-15.0, -5.0)
    y_lim3 = (-20.0, 30.0)
    y_lim4 = (-15.0, -5.0)
    y_lim5 = (-20.0, 30.0)
    y_lim6 = (-15.0, -5.0)
    y_tick_interval = 5.0  # tick interval along y
    y_tick_interval2 = 1.0  # tick interval along y
    y_tick_interval3 = 5.0
    y_tick_interval4 = 1.0
    y_tick_interval5 = 5.0
    y_tick_interval6 = 1.0
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # Constants:
    #   T: constant temperature in rates vs P plot
    #   P: constant pressure in rates vs T plot
    #   Coh: constant Coh
    # T = 9.8426e+02; P = 1.1053e10
    T = 1173.2; P = 14.0e9
    Coh = 150.0  # Concentration of water in wt.ppm H2O

    kappa = 1e-6 # m^2/s
    D = 100e3 # m

    # Visualization of rate variations
    fig, axes = plt.subplots(1, 3, tight_layout=True, figsize=(3*8*scaling_factor, 6*scaling_factor))

    # ---- Panel 1: rates vs Pressure (at constant T) ----
    ax1 = axes[0]

    Ps = np.arange(10e9, 17e9, 0.1e9)  # Pressure range in Pascals
    Ts = np.full(Ps.shape, T)
    Ps_eq = PT410["P"] + (Ts - PT410["T"])*PT410["cl"]
    Ts_eq = PT410["T"] + (Ps - PT410["P"])/PT410["cl"]

    # Rates are only evaluated where wadsleyite is stable (P > P_eq); zero elsewhere
    nucleation_rate = np.zeros(Ps.shape)
    mask = Ps > Ps_eq

    if pTKinetics.nucleation_type == 0:
        nucleation_rate[mask] = pTKinetics.nucleation_rate(Ps[mask], Ts[mask], Ps_eq[mask], Ts_eq[mask])
    elif pTKinetics.nucleation_type == 1:
        nucleation_rate[mask] = 6.7/d0 * pTKinetics.nucleation_rate(Ps[mask], Ts[mask], Ps_eq[mask], Ts_eq[mask])
    else:
        raise NotImplementedError("Value of nucleation type needs to be 0 or 1.")

    growth_rate = np.zeros(Ps.shape)
    growth_rate[mask] = pTKinetics.growth_rate(Ps[mask], Ts[mask], Ps_eq[mask], Ts_eq[mask], Coh)  # calculate growth rate

    ax1.plot(Ps / 1e9, np.log10(nucleation_rate), label="Nucleation Rate", color=default_colors[0]) # Plot nucleation rate on ax1
    ax1.set_xlabel("Pressure (GPa)")
    ax1.set_ylabel(r"log10($I_v$) ($m^{-3}s^{-1}$)", color=default_colors[0])
    ax1.tick_params(axis='y', labelcolor=default_colors[0])
    ax1.set_xlim(x_lim)
    ax1.set_ylim(y_lim)

    P_eq = PT410["P"] + (T - PT410["T"])*PT410["cl"]
    ax1.axvline(x=P_eq / 1e9, color="black", linestyle="--", label=r"$P_{410}$") # vertical line at the equilibrium pressure

    ax2 = ax1.twinx() # secondary y-axis for growth rate
    ax2.plot(Ps / 1e9, np.log10(growth_rate), label="Growth Rate (log)", color=default_colors[1])
    ax2.set_ylabel("log10(Y) (m/s)", color=default_colors[1])
    ax2.tick_params(axis='y', labelcolor=default_colors[1])
    ax2.set_ylim(y_lim2)

    ax1.set_title("T = %.1f K, Coh = %.1f ppm" % (T, Coh))

    ax1.grid()


    # ---- Panel 2: rates vs 1000/T (at constant P) ----
    ax3 = axes[1]

    T_invert = np.arange(0.4, 1.1, 0.01)  # 1000/T range
    Ts_1 = 1000.0 / T_invert  # Temperature in Kelvin
    P_array = np.ones(Ts_1.shape) * P # make an array before passing to function

    Ps_eq_1 = PT410["P"] + (Ts_1 - PT410["T"])*PT410["cl"]
    Ts_eq_1 = PT410["T"] + (P_array - PT410["P"])/PT410["cl"]

    nucleation_rate_1 = np.zeros(Ts_1.shape)

    mask = (P_array > Ps_eq_1) # compute mask before passing to function

    if pTKinetics.nucleation_type == 0:
        nucleation_rate_1[mask] = pTKinetics.nucleation_rate(P_array[mask], Ts_1[mask], Ps_eq_1[mask], Ts_eq_1[mask])
    elif pTKinetics.nucleation_type == 1:
        nucleation_rate_1[mask] = 6.7/d0 * pTKinetics.nucleation_rate(P_array[mask], Ts_1[mask], Ps_eq_1[mask], Ts_eq_1[mask])
    else:
        raise NotImplementedError("Value of nucleation type needs to be 0 or 1.")

    ax3.plot(T_invert, np.log10(nucleation_rate_1), color=default_colors[0])
    ax3.set_xlim(x_lim3)
    ax3.set_ylim(y_lim3)
    ax3.set_xlabel("1000/T (K)")
    ax3.set_ylabel(r"log10($I_v$) ($m^{-3}s^{-1}$)", color=default_colors[0])
    ax3.tick_params(axis='y', labelcolor=default_colors[0])

    T_eq_1 = PT410["T"] + (P - PT410["P"])/PT410["cl"]
    ax3.axvline(x=1000.0 / T_eq_1, color="black", linestyle="--", label=r"$T_{410}$") # vertical line at the equilibrium temperature

    ax4 = ax3.twinx() # secondary y-axis for growth rate vs temperature

    growth_rate_1 = np.zeros(Ts_1.shape)
    growth_rate_1[mask] = pTKinetics.growth_rate(P_array[mask], Ts_1[mask], Ps_eq_1[mask], Ts_eq_1[mask], Coh)
    ax4.plot(T_invert, np.log10(growth_rate_1), color=default_colors[1])
    ax4.set_ylabel("log10(Y) (m/s)", color=default_colors[1])
    ax4.set_ylim(y_lim4)
    ax4.tick_params(axis='y', labelcolor=default_colors[1])

    ax3.grid()

    ax3.set_title("P = %.2f GPa, Coh = %.1f ppm" % (P/1e9, Coh))

    # ---- Panel 3: rates vs Coh (at constant P and T) ----
    ax5 = axes[2]

    log10_Cohs = np.arange(2, 4, 0.05)  # Logarithmic OH content range
    Cohs = 10**log10_Cohs  # OH content in wt.ppm H2O

    Ts_2 = np.full(Cohs.shape, T)
    Ps_2 = np.full(Cohs.shape, P)
    Ps_eq_2 = np.full(Cohs.shape, PT410["P"] + (T - PT410["T"])*PT410["cl"])
    Ts_eq_2 = np.full(Cohs.shape, PT410["T"] + (P - PT410["P"])/PT410["cl"])

    growth_rate_2 = pTKinetics.growth_rate(Ps_2, Ts_2, Ps_eq_2, Ts_eq_2, Cohs)

    if pTKinetics.nucleation_type == 0:
        nucleation_rate_2 = pTKinetics.nucleation_rate(Ps_2, Ts_2, Ps_eq_2, Ts_eq_2)
    elif pTKinetics.nucleation_type == 1:
        nucleation_rate_2 = 6.7/d0 * pTKinetics.nucleation_rate(Ps_2, Ts_2, Ps_eq_2, Ts_eq_2)
    else:
        raise NotImplementedError("Value of nucleation type needs to be 0 or 1.")

    ax5.plot(np.log10(Cohs), np.log10(nucleation_rate_2), color=default_colors[0])
    ax5.set_xlim(x_lim5)
    ax5.set_ylim(y_lim5)
    ax5.set_xlabel(r"log10($C_{OH}$) (wt. ppm H2O)")
    ax5.set_ylabel(r"log10($I_v$) ($m^{-3}s^{-1}$)", color=default_colors[0])
    ax5.tick_params(axis='y', labelcolor=default_colors[0])

    ax6 = ax5.twinx() # secondary y-axis for growth rate vs Coh
    ax6.plot(np.log10(Cohs), np.log10(growth_rate_2), color=default_colors[1])
    ax6.set_ylabel(r"log10(Y) (m/s)", color=default_colors[1])
    ax6.tick_params(axis='y', labelcolor=default_colors[1])
    ax6.set_ylim(y_lim6)

    ax5.set_title("T = %.1f K, P = %.2f GPa" % (T, P/1e9))

    # set ticks
    ax1.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax1.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax1.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax1.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax2.yaxis.set_major_locator(MultipleLocator(y_tick_interval2))
    ax2.yaxis.set_minor_locator(MultipleLocator(y_tick_interval2/(n_minor_ticks+1)))

    ax3.xaxis.set_major_locator(MultipleLocator(x_tick_interval3))
    ax3.xaxis.set_minor_locator(MultipleLocator(x_tick_interval3/(n_minor_ticks+1)))
    ax3.yaxis.set_major_locator(MultipleLocator(y_tick_interval3))
    ax3.yaxis.set_minor_locator(MultipleLocator(y_tick_interval3/(n_minor_ticks+1)))

    ax4.yaxis.set_major_locator(MultipleLocator(y_tick_interval4))
    ax4.yaxis.set_minor_locator(MultipleLocator(y_tick_interval4/(n_minor_ticks+1)))

    ax5.xaxis.set_major_locator(MultipleLocator(x_tick_interval5))
    ax5.xaxis.set_minor_locator(MultipleLocator(x_tick_interval5/(n_minor_ticks+1)))
    ax5.yaxis.set_major_locator(MultipleLocator(y_tick_interval5))
    ax5.yaxis.set_minor_locator(MultipleLocator(y_tick_interval5/(n_minor_ticks+1)))

    ax6.yaxis.set_major_locator(MultipleLocator(y_tick_interval6))
    ax6.yaxis.set_minor_locator(MultipleLocator(y_tick_interval6/(n_minor_ticks+1)))

    ax5.grid()

    # Adjust spine thickness on every axis of THIS figure (including the twinx axes);
    # a bare ``ax`` here would refer to a stale axis left over from an earlier cell.
    for ax_i in fig.axes:
        for spine in ax_i.spines.values():
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # save outputs
    file_out = os.path.join(o_dir, "Yoshioka_2015_rates_P_%.1fGPa_T_%.1fK_Coh_%.1f_nu_%d.pdf" % (P/1e9, T, Coh, pTKinetics.nucleation_type))
    fig.savefig(file_out)
    print("Saved figure: %s" % file_out)

    # Reset rcParams to defaults
    rcdefaults()

In [None]:
if plot_analysis:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator

    # This cell reuses quantities computed in the previous cell: nucleation_rate*,
    # growth_rate*, Ps, T_invert, Cohs, P_eq, T_eq_1, T, P, Coh, d0, kappa, D and o_dir.

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Example usage
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 2.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines

    x_lim = (11.0, 16.0)
    x_lim3 = (0.5, 1.1)
    x_lim5 = (2.0, 4.0)
    y_lim = (-20.0, 20.0)
    y_lim1_1 = (-1.0, 1.0)
    y_lim3 = (-20.0, 20.0)
    y_lim3_1 = (-5.0, 5.0)
    y_lim5 = (-16.0, -11.0)
    y_lim5_1 = (-5.0, 5.0)
    x_tick_interval = 1.0   # tick interval along x
    x_tick_interval3 = 0.1   # tick interval along x
    x_tick_interval5 = 0.5   # tick interval along x (Coh panel)
    y_tick_interval = 10.0
    y_tick_interval1_1 = 0.5
    y_tick_interval3 = 10.0
    y_tick_interval3_1 = 2.5
    y_tick_interval5 = 1.0
    y_tick_interval5_1 = 2.5

    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # Time scale used to nondimensionalize time (t = D**2 / kappa * sigma, see the
    # site-saturation notes above); multiply sigma_s by it to get the saturation time in s.
    t_scale = D**2 / kappa

    # initiate figure
    fig, axes = plt.subplots(1, 3, tight_layout=True, figsize=(3*8*scaling_factor, 6*scaling_factor))

    # ---- Panel 1: saturation / growth time vs Pressure ----
    ax1 = axes[0]
    ax1_1 = ax1.twinx()
    sigma_s = Meta.calculate_sigma_s(nucleation_rate, growth_rate, d0, kappa=kappa, D=D)
    t_g = d0 / 6.7 / growth_rate  # growth time (s)

    ax1.plot(Ps / 1e9, np.log10(sigma_s * t_scale / year), color=default_colors[0])
    ax1_1.plot(Ps / 1e9, np.log10(t_g/year), color=default_colors[1])

    ax1.axvline(x=P_eq / 1e9, color="black", linestyle="--", label=r"$P_{eq}$") # vertical line at the equilibrium pressure

    ax1.grid()

    ax1.set_xlim(x_lim)
    ax1.set_ylim(y_lim)

    ax1.set_xlabel("Pressure (GPa)")
    ax1.set_ylabel("log (Saturation Time) (year)", color=default_colors[0])
    ax1_1.set_ylabel("log (Growth Time) (year)", color=default_colors[1])

    ax1.tick_params(axis='y', labelcolor=default_colors[0])
    ax1_1.tick_params(axis='y', labelcolor=default_colors[1])

    ax1.set_title("T = %.1f K, Coh = %.1f ppm" % (T, Coh))


    # ---- Panel 2: saturation / growth time vs 1000/T ----
    ax3 = axes[1]
    ax3_1 = ax3.twinx()

    sigma_s_1 = Meta.calculate_sigma_s(nucleation_rate_1, growth_rate_1, d0, kappa=kappa, D=D)
    t_g_1 = d0 / 6.7 / growth_rate_1

    ax3.plot(T_invert, np.log10(sigma_s_1 * t_scale / year), color=default_colors[0])
    ax3_1.plot(T_invert, np.log10(t_g_1/year), color=default_colors[1])
    ax3.axvline(x=1000.0 / T_eq_1, color="black", linestyle="--", label=r"$T_{eq}$") # vertical line at the equilibrium temperature

    ax3.set_xlim(x_lim3)
    ax3.set_ylim(y_lim3)
    ax3_1.set_ylim(y_lim3_1)

    ax3.grid()

    ax3.set_xlabel("1000/T (K)")
    ax3.set_ylabel("log (Saturation Time) (year)", color=default_colors[0])
    ax3_1.set_ylabel("log (Growth Time) (year)", color=default_colors[1])

    ax3.tick_params(axis='y', labelcolor=default_colors[0])
    ax3_1.tick_params(axis='y', labelcolor=default_colors[1])

    ax3.set_title("P = %.2f GPa, Coh = %.1f ppm" % (P/1e9, Coh))

    # ---- Panel 3: saturation / growth time vs Coh ----
    ax5 = axes[2]
    ax5_1 = ax5.twinx()

    sigma_s_2 = Meta.calculate_sigma_s(nucleation_rate_2, growth_rate_2, d0, kappa=kappa, D=D)
    t_g_2 = d0 / 6.7 / growth_rate_2

    ax5.plot(np.log10(Cohs), np.log10(sigma_s_2 * t_scale / year), color=default_colors[0])
    ax5_1.plot(np.log10(Cohs), np.log10(t_g_2/year), color=default_colors[1])

    ax5.set_xlabel("log10(OH content) (wt. ppm H2O)")
    ax5.set_ylabel("log (Saturation Time) (year)", color=default_colors[0])
    ax5_1.set_ylabel("log (Growth Time) (year)", color=default_colors[1])

    ax5.set_title("T = %.1f K, P = %.2f GPa" % (T, P/1e9))

    ax5.set_xlim(x_lim5)
    ax5.set_ylim(y_lim5)
    ax5_1.set_ylim(y_lim5_1)

    ax5.tick_params(axis='y', labelcolor=default_colors[0])
    ax5_1.tick_params(axis='y', labelcolor=default_colors[1])

    ax5.grid()

    # set ticks
    ax1.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax1.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax1.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax1.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax1_1.yaxis.set_major_locator(MultipleLocator(y_tick_interval1_1))
    ax1_1.yaxis.set_minor_locator(MultipleLocator(y_tick_interval1_1/(n_minor_ticks+1)))

    ax3.xaxis.set_major_locator(MultipleLocator(x_tick_interval3))
    ax3.xaxis.set_minor_locator(MultipleLocator(x_tick_interval3/(n_minor_ticks+1)))
    ax3.yaxis.set_major_locator(MultipleLocator(y_tick_interval3))
    ax3.yaxis.set_minor_locator(MultipleLocator(y_tick_interval3/(n_minor_ticks+1)))

    ax3_1.yaxis.set_major_locator(MultipleLocator(y_tick_interval3_1))
    ax3_1.yaxis.set_minor_locator(MultipleLocator(y_tick_interval3_1/(n_minor_ticks+1)))

    ax5.xaxis.set_major_locator(MultipleLocator(x_tick_interval5))
    ax5.xaxis.set_minor_locator(MultipleLocator(x_tick_interval5/(n_minor_ticks+1)))
    ax5.yaxis.set_major_locator(MultipleLocator(y_tick_interval5))
    ax5.yaxis.set_minor_locator(MultipleLocator(y_tick_interval5/(n_minor_ticks+1)))

    ax5_1.yaxis.set_major_locator(MultipleLocator(y_tick_interval5_1))
    ax5_1.yaxis.set_minor_locator(MultipleLocator(y_tick_interval5_1/(n_minor_ticks+1)))

    # Adjust spine thickness on every axis of THIS figure (including the twinx axes);
    # a bare ``ax`` here would refer to a stale axis left over from an earlier cell.
    for ax_i in fig.axes:
        for spine in ax_i.spines.values():
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # save outputs
    file_out = os.path.join(o_dir, "Yoshioka_2015_summary_P_%.1fGPa_T_%.1fK.pdf" % (P/1e9, T))
    fig.savefig(file_out)
    print("Saved figure: %s" % file_out)

    # Reset rcParams to defaults
    rcdefaults()

## Solve the kinetics at specific P, T conditions

All variables $X_0$, $X_1$, $X_2$, $X_3$ are advected with particles in our numerical simulation using ASPECT. The differential equations are solved with a fourth-order Runge–Kutta method at every time step to update the metastable kinetics.

Here we test the solution with prescribed P, T conditions.

### Solve the kinetics for a (P, T) condition

In [None]:
# Tips: look at the class definition in file hamageolib/research/haoyuan_2d_subduction/metastable.py
is_solving_point = False

if is_solving_point:

    Coh = 150.0 # ppm H2O
    d0 = 1e-2 # m, initial grain size
    Peq = 13.5e9 # Pa, equilibrium pressure passed to set_PT_eq
    Teq = 1740.0 # K, equilibrium temperature passed to set_PT_eq
    Cl = 2e6 # Clapeyron slope passed to set_PT_eq (presumably Pa/K -- confirm)

    # Conditions to solve: same pressure, three temperatures
    Ps = [13.5e9] * 3
    Ts = [873.15, 973.15, 1573.15]

    nucleation_type = 1 # 0 - volumetric; 1 - surface

    # initiate the kinetics class
    _, _constants1 = Meta.get_kinetic_constants(nucleation_type)
    Mo_Kinetics = Meta.MO_KINETICS(_constants1, post_process=["ts", "tg"])
    Mo_Kinetics.set_initial_grain_size(d0)

    # set P T condition for solution
    Mo_Kinetics.set_PT_eq(Peq, Teq, Cl)
    Mo_Kinetics.link_and_set_kinetics_model(Meta.PTKinetics)

    print("Equilibrium T at %.3f" % Meta.compute_eq_T(PT410, Ps[0]))
    print("Compute kinetics at Ts:", Ts)

    # Parameters for solver
    t_max = 10e6 * year # s
    n_t = 100
    n_span = 10

    result_array = []
    for P, T in zip(Ps, Ts):
        Mo_Kinetics.set_kinetics_fixed(P, T, Coh)

        # solve the kinetics (solver output is silenced)
        with Mute():
            results = Mo_Kinetics.solve(P, T, 0.0, t_max, n_t, n_span, debug=True)

        # export result; np.savetxt opens and closes the file itself, so no extra
        # file handle is needed
        txt_file_path = os.path.join(results_dir, "solution_P%.2fGPa_T%.2fK.txt" % (P/1e9, T))
        np.savetxt(txt_file_path, results)
        print("Solution saved to %s" % txt_file_path)

        result_array.append(results)

#### For one of the conditions, plot the full kinetics

#### For all the different conditions, plot the volume of transformed material

We choose T = 873.15 K, as the transition is rate-controlled under this condition.

In [None]:
if is_solving_point:

    from matplotlib import gridspec, rcdefaults
    from matplotlib.ticker import MultipleLocator

    def _fmt(val, scale=1.0):
        """Format val/scale as '%.4e', or 'None' when val is None."""
        return "None" if val is None else f"{val/scale:.4e}"

    # index of the condition to plot (Ps[idx], Ts[idx], result_array[idx])
    # NOTE(review): the markdown above mentions T = 873.15 K, which is Ts[0] in the
    # solving cell; idx = 1 selects Ts[1] = 973.15 K — confirm the intended index.
    idx = 1

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 1.0
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (0.0, 1000) # kyr (time is plotted in kyr)
    x_tick_interval = 200.0   # tick interval along x (five major ticks over x_lim)
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    P = Ps[idx]
    T = Ts[idx]

    # parse result for one P, T condition
    results = result_array[idx]
    time_values = results[:, 0]  # s; not named `time` to avoid shadowing the time module
    total_grain_number = results[:, 1]
    total_grain_diameter = results[:, 2]
    extended_volume = results[:, 4]
    volume = results[:, 5]
    is_saturated_a = results[:, 6]
    time_saturated_a = results[:, 7]
    time_kyr = time_values / year / 1e3

    # Average grain size from the transformed volume and the grain number, treating the
    # grains as equal spheres: volume = N * (pi/6) * d^3  ->  d = (6 * volume / (pi * N))^(1/3)
    with np.errstate(divide="ignore", invalid="ignore"):  # in case N == 0
        average_grain_diameter = ((volume / total_grain_number) * 6.0 / np.pi)**(1.0/3.0)

    # find condition at saturation
    # NOTE(review): Mo_Kinetics keeps the state of the last call in the solving cell,
    # which may not be the condition at idx — confirm.
    total_grain_number_saturated_value = Mo_Kinetics.X_saturated[0]
    total_grain_diameter_saturated_value = Mo_Kinetics.X_saturated[1]
    extended_volume_saturated_value = Mo_Kinetics.X_saturated[3]

    average_grain_diameter_saturated_value = total_grain_diameter_saturated_value / total_grain_number_saturated_value

    # find the critical radius
    critical_radius_value = Mo_Kinetics.compute_rc(0)
    extended_volume_critial_saturated_value = 4 * np.pi / 3 * critical_radius_value**3.0 * total_grain_number_saturated_value

    # first time step flagged as saturated (None if the run never saturates)
    indices = np.flatnonzero(is_saturated_a == 1.0)
    time_saturated_value = time_saturated_a[indices[0]] if indices.size > 0 else None

    print(
        "P = %s GPa, T = %s K, ts: %s year, d: %s m, rc: %s m, "
        "extended volume: %s, extended volume by critical radius: %s"
        % (
            _fmt(P, 1e9),
            _fmt(T),
            _fmt(time_saturated_value, year),
            _fmt(average_grain_diameter_saturated_value),
            _fmt(critical_radius_value),
            _fmt(extended_volume_saturated_value),
            _fmt(extended_volume_critial_saturated_value),
        )
    )

    # Start plot
    fig = plt.figure(figsize=(2.1*8*scaling_factor, 2.1*5*scaling_factor), tight_layout=True)
    gs = gridspec.GridSpec(2, 2)

    # plot the grain number and total diameter
    ax = fig.add_subplot(gs[0, 0])
    ax.semilogy(time_kyr, total_grain_number, color=default_colors[idx], label='Grain Number, P=%.2f Gpa, T=%.2f K' %(P/1e9, T))
    ax.set_xlabel('Time (kyr)')
    ax.set_ylabel('Grain Number (m^-3)')

    ax.set_xlim(x_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))

    ax1 = ax.twinx()
    ax1.semilogy(time_kyr, total_grain_diameter, color=default_colors[idx], label='Grain Diameter, P=%.2f Gpa, T=%.2f K' %(P/1e9, T), linestyle="--")
    ax1.set_ylabel('Total Grain Diameter (m^-2)')

    # plot the average grain size
    ax = fig.add_subplot(gs[1, 0])
    ax.plot(time_kyr, np.log10(average_grain_diameter), color=default_colors[idx], label='Average Grain Size, P=%.2f Gpa, T=%.2f K' %(P/1e9, T))
    ax.hlines(np.log10(average_grain_diameter_saturated_value), x_lim[0], x_lim[1], label='(Saturation)', color=default_colors[idx], linestyle="--")
    ax.set_xlabel('Time (kyr)')
    ax.set_ylabel('log10(Grain Size (m))')

    ax.set_xlim(x_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))

    ax.legend()

    # plot the volume and the extended volume
    ax = fig.add_subplot(gs[0, 1])
    ax.plot(time_kyr, volume, color=default_colors[idx], label='Volume, P=%.2f Gpa, T=%.2f K' %(P/1e9, T))
    ax.set_xlabel('Time (kyr)')
    ax.set_ylabel('Volume')

    ax.set_xlim(x_lim)
    ax.set_ylim((0.0, 1.0))

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(0.2))
    ax.yaxis.set_minor_locator(MultipleLocator(0.2/(n_minor_ticks+1)))

    ax.grid()

    ax1 = ax.twinx()
    ax1.plot(time_kyr, extended_volume, linestyle="--", color=default_colors[idx], label="Extended Volume")
    ax1.set_ylim((0.0, 5.0))
    ax1.yaxis.set_major_locator(MultipleLocator(1.0))
    ax1.yaxis.set_minor_locator(MultipleLocator(1.0/(n_minor_ticks+1)))
    ax1.set_ylabel('Extended Volume')

    # combine the legends of the two y-axes
    handles1, labels1 = ax.get_legend_handles_labels()
    handles2, labels2 = ax1.get_legend_handles_labels()
    ax.legend(handles1 + handles2, labels1 + labels2, loc="best")

    # Reset rcParams to defaults
    rcdefaults()

In [None]:
if is_solving_point:

    from matplotlib import gridspec, rcdefaults
    from matplotlib.ticker import MultipleLocator

    def fmt(val, scale=1.0):
        """Format val/scale as '%.4e', or 'None' when val is None."""
        return "None" if val is None else f"{val/scale:.4e}"

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 2.0 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (0.0, 1e4) # kyr
    y_lim1 = (0.0, 1.0)
    y_lim2 = (-10.0, 10.0)
    x_tick_interval = 2e3   # tick interval along x
    y_tick_interval1 = 0.25  # tick interval along y
    y_tick_interval2 = 5.0  # tick interval along y
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    fig = plt.figure(figsize=(16*scaling_factor, 10*scaling_factor), tight_layout=True)
    gs = gridspec.GridSpec(2, 2)
    ax1 = fig.add_subplot(gs[0, 0]) # V
    ax2 = ax1.twinx() # Extended V
    ax3 = fig.add_subplot(gs[1, 0]) # grain size
    ax4 = fig.add_subplot(gs[0, 1]) # V and grain size
    ax5 = ax4.twinx() # grain size

    # NOTE(review): Mo_Kinetics is not updated inside this loop, so X_saturated and
    # compute_rc(0) give the same values for every condition (presumably those of the
    # last condition solved in the solving cell) — confirm.
    for i, (P, T, results) in enumerate(zip(Ps, Ts, result_array)):

        # parse result for one P, T condition
        time_values = results[:, 0]  # s; not named `time` to avoid shadowing the time module
        total_grain_number = results[:, 1]
        total_grain_diameter = results[:, 2]
        extended_volume = results[:, 4]
        volume = results[:, 5]
        is_saturated_a = results[:, 6]
        time_saturated_a = results[:, 7]
        time_kyr = time_values / year / 1e3

        print("P: ", P)
        print("T: ", T)
        print("total_grain_number: ", total_grain_number)

        # values at saturation
        total_grain_number_saturated_value = Mo_Kinetics.X_saturated[0]
        total_grain_diameter_saturated_value = Mo_Kinetics.X_saturated[1]
        extended_volume_saturated_value = Mo_Kinetics.X_saturated[3]

        # average grain diameter: at saturation, and over time assuming equal
        # spherical grains (volume = N * (pi/6) * d^3)
        average_grain_diameter_saturated_value = total_grain_diameter_saturated_value / total_grain_number_saturated_value
        with np.errstate(divide="ignore", invalid="ignore"):  # in case N == 0
            average_grain_diameter = ((volume / total_grain_number) * 6.0 / np.pi)**(1.0/3.0)

        # find the critical radius
        critical_radius_value = Mo_Kinetics.compute_rc(0)

        # first time step flagged as saturated (None if the run never saturates)
        indices = np.flatnonzero(is_saturated_a == 1.0)
        if indices.size > 0:
            time_saturated_value = time_saturated_a[indices[0]]
            extended_volume_critial_saturated_value = 4 * np.pi / 3 * critical_radius_value**3.0 * total_grain_number_saturated_value
        else:
            time_saturated_value = None
            extended_volume_critial_saturated_value = None

        print(
            "P = %s GPa, T = %s K, ts: %s year, d: %s m, rc: %s m, "
            "extended volume: %s, extended volume by critical radius: %s"
            % (
                fmt(P, 1e9),
                fmt(T),
                fmt(time_saturated_value, year),
                fmt(average_grain_diameter_saturated_value),
                fmt(critical_radius_value),
                fmt(extended_volume_saturated_value),
                fmt(extended_volume_critial_saturated_value),
            )
        )

        # volume (left y-axis) and log10 of the extended volume (right y-axis)
        ax1.plot(time_kyr, volume, color=default_colors[i], linewidth=4, label='Volume, P=%.2f Gpa, T=%.2f K' %(P/1e9, T))
        ax2.plot(time_kyr, np.log10(extended_volume), color=default_colors[i],
                 label='Extended Volume' if i == 0 else None)

        # average grain size
        ax3.plot(time_kyr, np.log10(average_grain_diameter), color=default_colors[i], linewidth=4, label='Dav, P=%.2f Gpa, T=%.2f K' %(P/1e9, T))

        # V and grain size
        ax4.plot(time_kyr, volume, color=default_colors[i], linewidth=4, label='Volume, P=%.2f Gpa, T=%.2f K' %(P/1e9, T))
        ax5.plot(time_kyr, np.log10(average_grain_diameter), color=default_colors[i], linestyle="--", linewidth=4, label='Dav, P=%.2f Gpa, T=%.2f K' %(P/1e9, T))

    # labels, title, grid and legends: set once after all conditions are drawn
    fig.suptitle('Volume and Extended Volume vs Time')

    ax1.set_xlabel('Time (kyr)')
    ax1.set_ylabel('Volume', color='tab:blue')
    ax1.tick_params(axis='y', labelcolor='tab:blue')
    ax1.grid(True)
    ax2.set_ylabel('log(Extended Volume)', color='tab:red')
    ax2.tick_params(axis='y', labelcolor='tab:red')

    lines_1, labels_1 = ax1.get_legend_handles_labels()
    lines_2, labels_2 = ax2.get_legend_handles_labels()
    ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='upper left')

    # Adjust spine thickness for the volume plot
    for spine in (*ax1.spines.values(), *ax2.spines.values()):
        spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    ax3.set_xlabel('Time (kyr)')
    ax3.set_ylabel('log10(Average Grain Size)')

    ax4.set_xlabel('Time (kyr)')
    ax4.set_ylabel('Volume')
    ax5.set_ylabel('log10(Average Grain Size)')

    lines_4, labels_4 = ax4.get_legend_handles_labels()
    lines_5, labels_5 = ax5.get_legend_handles_labels()
    ax4.legend(lines_4 + lines_5, labels_4 + labels_5)

    # axis configuration
    ax1.set_xlim(x_lim)
    ax1.set_ylim(y_lim1)
    ax2.set_ylim(y_lim2)
    ax3.set_xlim(x_lim)
    ax4.set_xlim([0, 1000.0])
    ax4.set_ylim([0.0, 1.05])
    ax5.set_ylim([-10.0, 0.5])

    ax1.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax1.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax1.yaxis.set_major_locator(MultipleLocator(y_tick_interval1))
    ax1.yaxis.set_minor_locator(MultipleLocator(y_tick_interval1/(n_minor_ticks+1)))

    ax2.yaxis.set_major_locator(MultipleLocator(y_tick_interval2))
    ax2.yaxis.set_minor_locator(MultipleLocator(y_tick_interval2/(n_minor_ticks+1)))

    ax3.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax3.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax3.yaxis.set_major_locator(MultipleLocator(2.5))
    ax3.yaxis.set_minor_locator(MultipleLocator(2.5/(n_minor_ticks+1)))

    ax4.xaxis.set_major_locator(MultipleLocator(200.0))
    ax4.xaxis.set_minor_locator(MultipleLocator(200.0/(n_minor_ticks+1)))
    ax4.yaxis.set_major_locator(MultipleLocator(0.2))
    ax4.yaxis.set_minor_locator(MultipleLocator(0.2/(n_minor_ticks+1)))

    ax5.yaxis.set_major_locator(MultipleLocator(2.0))
    ax5.yaxis.set_minor_locator(MultipleLocator(2.0/(n_minor_ticks+1)))

    ax4.grid()

    # maintain a tight layout
    fig.tight_layout()

    # save figure before showing it (some backends close the figure on show)
    fig_path = os.path.join(results_dir, "metastable_illustration_nu_%d_d0_%.2e_Coh_%.2e.pdf" % (nucleation_type, d0, Coh))
    fig.savefig(fig_path)

    # show figure
    plt.show()
    print("Save figure %s" % fig_path)

    # Reset rcParams to defaults
    rcdefaults()


### Solve the kinetics for a grid

(Supplementary Material: description related to the diagram)

We calculate synthetic diagrams of kinetic conditions under the simplifying assumption that latent heat effects are neglected.
The volumetric nucleation rate and the grain growth rate are investigated within their respective ranges of dynamic significance.

We then plot contours of the site-saturation timescale $t_s$ and the growth timescale $t_g$ (i.e., $t_{0.5}$ assuming early site saturation). The diagram is divided into four categories based on kinetic regimes: M — metastable, G — growth-controlled kinetics, N — nucleation-controlled kinetics, and E — equilibrium transition, using the following criteria:

$$ M: \quad t_s > 100~\mathrm{Ma},\ t_g > 100~\mathrm{Ma} $$
$$ G: \quad t_s \leq t_g,\ \min\left(t_s, t_g\right) < 100~\mathrm{Ma} $$
$$ N: \quad t_s > t_g,\ \min\left(t_s, t_g\right) < 100~\mathrm{Ma} $$
$$ E: \quad t_s < 10~\mathrm{kyr},\ t_g < 10~\mathrm{kyr} $$

Contours of the Avrami number equal to 1 are also plotted.
These contours illustrate the balance between the diffusion timescale and the kinetic transition timescale.
The two boundary thicknesses represent characteristic thermal diffusion length scales: 100 km for the entire lithospheric plate (black dashed line) and 5 km for the effective boundary of the metastable olivine wedge (grey dashed line).

There are two execution modes:
* serial: runs in a single process. It reports progress, which makes it useful for estimating the run time.
* parallel: runs across multiple processes to save time.

In [None]:
import multiprocessing
from joblib import Parallel, delayed
import time


is_kinetics_diagram = False  # if True, run the kinetics-diagram cells (solve, save and plot)

if is_kinetics_diagram:
    # Running options
    is_solving_kinetics_diagram = True  # if we solve new dataset and merge to old ones
    is_solving_kinetics_diagram_parallel = True # Solve results in python parallel
    is_read_dataset = True  # Read old dataset
    is_update_dataset = True # Update old dataset with new values
    full_mesh_PT = False # Use the full ranges of P, T to create the mesh

    # Equilibrium values
    Peq = 13.5e9
    Teq = 1740.0
    Cl = 2e6

    # Kinetic values
    Coh = 150.0 # ppm H2O
    d0 = 1e-2 # m, grain size
    nucleation_type = 1

    # Maximum time of the solution
    t_max = 1e6 * year
    # t_max = 10 * 1e6 * year # for plotting figure in the supplementary material

    # Range of P (Pa) and T (K)
    if full_mesh_PT:
        P_min, P_max = 0.0, 20e9
        T_min, T_max = 673.15, 1873.15
    else:
        # user-prescribed sub-range
        # P_min = 12e9; P_max = 14e9; T_min = 800.0; T_max = 1873.15 # Pa, K
        P_min, P_max = 11e9, 20e9
        T_min, T_max = 900.0, 1100.0
    N_P, N_T = 201, 101 # resolution of P, T

    # Numerical constants: resolution of time
    n_t, n_span = 10, 20
    N_t = n_t * n_span + 1 # total number of points along t

    # Initiation
    # Note: Pressure in Pascals, Temperature in Kelvin, Time in seconds
    # Create a meshgrid; with indexing="ij" the grids are indexed (T, P, t)
    P_values = np.linspace(P_min, P_max, N_P)  # global mesh
    T_values = np.linspace(T_min, T_max, N_T)
    t_values = np.linspace(0, t_max, N_t)

    T_grid, P_grid, t_grid = np.meshgrid(T_values, P_values, t_values, indexing="ij")
    V_grid = np.zeros(P_grid.shape)

#### Solve

In [None]:
# initiate the class
if is_kinetics_diagram and is_solving_kinetics_diagram:
    _, _constants1 = Meta.get_kinetic_constants(nucleation_type)
    Mo_Kinetics = Meta.MO_KINETICS(_constants1, post_process=["ts", "tg"])
    Mo_Kinetics.set_initial_grain_size(d0)

    Mo_Kinetics.set_PT_eq(Peq, Teq, Cl)
    Mo_Kinetics.link_and_set_kinetics_model(Meta.PTKinetics)

    # Function to solve for a given T, P
    def solve_metastable_kinetics(P, T, Coh, t_max, n_t, n_span, Mo_Kinetics):
        '''
        Solve the metastable kinetics at one fixed (P, T) and append diagnostics.
        Inputs:
            P - pressure (Pa)
            T - temperature (K)
            Coh - water content (ppm H2O)
            t_max - end time of the solution (s)
            n_t, n_span - time resolution; outputs have n_t*n_span+1 rows
            Mo_Kinetics - MO_KINETICS object; its fixed kinetics are reset to (P, T, Coh)
        Returns:
            2-d array: the columns returned by Mo_Kinetics.solve followed by
            Av, Av_c (D=5e3), Iv, Y, ts, tg evaluated at each output time
        '''
        Mo_Kinetics.set_kinetics_fixed(P, T, Coh)

        t_values = np.linspace(0, t_max, n_t*n_span+1)

        # nucleation rate, growth rate and the two kinetic timescales
        Iv_values = np.zeros(n_t*n_span+1)
        Y_values = np.zeros(n_t*n_span+1)
        ts_values = np.zeros(n_t*n_span+1)
        tg_values = np.zeros(n_t*n_span+1)
        for i, t in enumerate(t_values):
            Iv_values[i] = Mo_Kinetics.compute_Iv(t)
            Y_values[i] = Mo_Kinetics.compute_Y(t)
            ts_values[i] = Mo_Kinetics.compute_ts(t)
            tg_values[i] = Mo_Kinetics.compute_tg(t)

        # Avrami number with the default length scale and with D = 5e3
        Av_values = np.zeros(n_t*n_span+1)
        Av_c_values = np.zeros(n_t*n_span+1)
        for i, t in enumerate(t_values):
            Av_values[i] = Mo_Kinetics.compute_Av(t)
            Av_c_values[i] = Mo_Kinetics.compute_Av(t, D=5e3)

        # solve the ODEs
        results = Mo_Kinetics.solve(P, T, 0.0, t_max, n_t, n_span)

        # append the diagnostics as extra columns
        return np.column_stack((results, Av_values, Av_c_values, Iv_values,
                                Y_values, ts_values, tg_values))

    # Grid points are visited in (T index, P index) order; both branches below
    # produce results_raw as a flat list in that order.
    n_T_points, n_P_points = T_grid.shape[0], T_grid.shape[1]

    start = time.time()

    if is_solving_kinetics_diagram_parallel:
        # Solve in parallel
        num_processes = multiprocessing.cpu_count()  # number of available processes
        print(f"Number of available processes: {num_processes}")

        results_raw = Parallel(n_jobs=-1)(
            delayed(solve_metastable_kinetics)(
                P_grid[i, j, 0], T_grid[i, j, 0], Coh, t_max, n_t, n_span, Mo_Kinetics
            )
            for i in range(n_T_points) for j in range(n_P_points)
        )
    else:
        # Solve serially, reporting progress
        n_points = n_T_points * n_P_points
        results_raw = []
        for i in range(n_T_points):
            for j in range(n_P_points):
                results_raw.append(solve_metastable_kinetics(
                    P_grid[i, j, 0], T_grid[i, j, 0], Coh, t_max, n_t, n_span, Mo_Kinetics))
                sys.stdout.write("\rsolved %d / %d" % (i*n_P_points+j, n_points-1)) # debug
                sys.stdout.flush()

    end = time.time()
    print("\nSolve Metastable Kinetics took %.2f s" % (end - start))

    # Convert results_raw to a structured grid: (N_T, N_P, time steps, variables)
    results_array = np.array(results_raw)  # (N_T*N_P, time steps, variables)
    time_steps, num_columns = results_array.shape[1], results_array.shape[2]
    results_grid = results_array.reshape(n_T_points, n_P_points, time_steps, num_columns)

    # Transformed volume (column 5 of the solver output) across all grid points
    V_grid = results_grid[:, :, :, 5]

    # Convert to a 2-d table: one row per (grid point, time step) holding
    # [P, T, result row...]; row order follows the flat grid-point order above.
    P_column = np.repeat(P_grid[:, :, 0].ravel(), time_steps)
    T_column = np.repeat(T_grid[:, :, 0].ravel(), time_steps)
    data_array = np.column_stack((P_column, T_column, results_array.reshape(-1, num_columns)))

    # Convert to pandas object
    data_columns = Mo_Kinetics.result_columns + ["Av", "Av_c", "Iv", "Y", "ts", "tg"]
    columns = ["P", "T"] + data_columns

    new_data = pd.DataFrame(data_array, columns=columns)

Next, the results are saved as grid data. We first check for a previous data file; if one exists, it is updated with the newly computed data points.

In [None]:
if is_kinetics_diagram:

    from datetime import datetime
    from shutil import copy

    # working directory and file
    current_results_dir = os.path.join(results_dir, "grid_data_09102025")
    os.makedirs(current_results_dir, exist_ok=True)

    data_file = os.path.join(current_results_dir, "metastable_grid_data.parquet")
    data_file_csv = os.path.join(current_results_dir, "metastable_grid_data.csv")

    # timestamp for the backup file name; defined up front so that the backup
    # step also works when the old dataset is not read
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    if is_read_dataset:
        if not os.path.isfile(data_file):
            raise FileNotFoundError("data file %s does not exist" % data_file)

        data = pd.read_parquet(data_file)

        print("loaded data file %s" % data_file)

        if is_solving_kinetics_diagram:

            # Identify existing combinations of P and T in data
            existing_combinations = set(zip(data["P"], data["T"]))

            # Filter new_data to exclude rows with existing P and T combinations
            is_existing = np.array(
                [pt in existing_combinations for pt in zip(new_data["P"], new_data["T"])],
                dtype=bool,
            )
            filtered_new_data = new_data[~is_existing]

            # Merge the filtered new_data into data
            data = pd.concat([data, filtered_new_data], ignore_index=True)

    else:
        data = new_data

    # Save data
    if is_update_dataset:
        # First, back up the previous file (if any) before overwriting it
        if os.path.isfile(data_file):
            data_file_backup = os.path.join(current_results_dir, f"metastable_grid_data_{timestamp}.parquet")
            copy(data_file, data_file_backup)
            print("created backup file %s" % data_file_backup)

        # NOTE: to_parquet needs a parquet engine (pyarrow or fastparquet); the
        # original author noted it failed due to missing packages.
        data.to_parquet(data_file, index=False)
        data.to_csv(data_file_csv, index=False)

        print("saved parquet file %s" % data_file)
        print("saved csv file %s" % data_file_csv)

#### Plot the diagram

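For interpolation, a nearest-neighbor interpolator (`NearestNDInterpolator`) is built for each result column in scaled (T, P, t) space. An inverse distance weighting (IDW) helper, `idw_interpolation`, is also available among the utility functions.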

In [None]:
if is_kinetics_diagram:

    from scipy.interpolate import NearestNDInterpolator


    def categorize_the_diagram(ts, tg):
        '''
        Categorize the kinetic diagram from the site-saturation and growth timescales.
        Inputs:
            ts - time for site saturation (s); scalar or numpy array
            tg - time for grain growth (s); scalar or array broadcastable with ts
        Returns:
            0 - M, metastable: ts > 100 Myr and tg > 100 Myr
            1 - G, growth-controlled: min(ts, tg) < 100 Myr and ts <= tg
            2 - N, nucleation-controlled: min(ts, tg) < 100 Myr and ts > tg
            3 - E, equilibrium transition: ts < 10 kyr and tg < 10 kyr
            A Python int for scalar inputs, an int numpy array otherwise.
            Scalar and array inputs follow the same rules.
        '''
        # assign a limit for metastability in geodynamic timescale
        t_eq = 1e4 * year
        t_meta = 1e8 * year

        ts_a = np.asarray(ts, dtype=float)
        tg_a = np.asarray(tg, dtype=float)

        # default 0 (M); later assignments take precedence (E overrides G/N)
        value = np.zeros(np.broadcast(ts_a, tg_a).shape, dtype=int)
        transitional = (ts_a < t_meta) | (tg_a < t_meta)
        value[transitional & (ts_a <= tg_a)] = 1
        value[transitional & (ts_a > tg_a)] = 2
        value[(ts_a < t_eq) & (tg_a < t_eq)] = 3

        if value.ndim == 0:
            return int(value)
        return value

    ## Remesh the grid from data
    # NOTE(review): this remesh (0-30 GPa, 273.15-1873.15 K) extends beyond the P, T
    # ranges set in the setup cell; NearestNDInterpolator returns the value of the
    # nearest data point there, i.e. edge values are repeated.
    N_P_1 = N_P
    N_T_1 = N_T
    P_values_1 = np.linspace(0.0, 30e9, N_P_1)  # Pressure in Pascals
    T_values_1 = np.linspace(273.15, 1873.15, N_T_1)  # Temperature in Kelvin

    T_grid_1, P_grid_1 = np.meshgrid(T_values_1, P_values_1)

    T_flat_1 = T_grid_1.flatten()
    P_flat_1 = P_grid_1.flatten()

    # compute equilibrium values
    P_eq_values_1 = Meta.compute_eq_P(Mo_Kinetics.PT_eq, T_values_1)

    # perform interpolation
    # use scaling factors for T, P, t to normalize the values, so that distances
    # in the (T/T0, P/P0, t/t0) space weigh the three coordinates comparably
    P0 = 1e9 # Pa
    T0 = 100.0 # K
    t0 = 1000 * year # s

    scaled_points = np.column_stack((data["T"]/T0, data["P"]/P0, data["t"]/t0))
    interpolators = {
        col: NearestNDInterpolator(scaled_points, data[col]) for col in data_columns
    }

1. Plot the volumetric nucleation rate $I_v$, the grain growth rate $Y$, the timescales $t_s$ and $t_g$, and the Avrami number (Av)

In [None]:
if is_kinetics_diagram:

    import matplotlib.colors as mcolors
    from matplotlib.ticker import MultipleLocator
    from matplotlib import rcdefaults

    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 2.0 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (400.0, 1800.0)
    x_tick_interval = 200.0   # tick interval along x
    y_lim = (10.0, 30.0)
    y_tick_interval = 5.0  # tick interval along y
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    idx = np.arange(0, N_P * N_T * N_t, N_t)

    # Iv_grid = data["Iv"].to_numpy()[np.ix_(idx)].reshape(T_grid_1.shape)
    # Y_grid = data["Y"].to_numpy()[np.ix_(idx)].reshape(T_grid_1.shape)
    # Av_grid = data["Av"].to_numpy()[np.ix_(idx)].reshape(T_grid_1.shape)

    # perform interpolation
    Iv_flat_1 = interpolators["Iv"](T_flat_1/T0, P_flat_1/P0, np.full(T_flat_1.shape, 0.0))
    Iv_grid = Iv_flat_1.reshape(T_grid_1.shape)
    Y_flat_1 = interpolators["Y"](T_flat_1/T0, P_flat_1/P0, np.full(T_flat_1.shape, 0.0))
    Y_grid = Y_flat_1.reshape(T_grid_1.shape)
    Av_flat_1 = interpolators["Av"](T_flat_1/T0, P_flat_1/P0, np.full(T_flat_1.shape, 0.0))
    Av_grid = Av_flat_1.reshape(T_grid_1.shape)
    Av_c_flat_1 = interpolators["Av_c"](T_flat_1/T0, P_flat_1/P0, np.full(T_flat_1.shape, 0.0))
    Av_c_grid = Av_c_flat_1.reshape(T_grid_1.shape)
    ts_flat_1 = interpolators["ts"](T_flat_1/T0, P_flat_1/P0, np.full(T_flat_1.shape, 0.0))
    ts_grid = ts_flat_1.reshape(T_grid_1.shape)
    tg_flat_1 = interpolators["tg"](T_flat_1/T0, P_flat_1/P0, np.full(T_flat_1.shape, 0.0))
    tg_grid = tg_flat_1.reshape(T_grid_1.shape)

    # Create subplots
    fig = plt.figure(figsize=(14*scaling_factor, 14*scaling_factor), tight_layout=True)
    gs = gridspec.GridSpec(2, 2)

    # Plot Iv
    ax = fig.add_subplot(gs[0, 0])
    contours = ax.contour(T_grid_1, P_grid_1 / 1e9, np.log10(Iv_grid), (-10, 0, 10, 20), vmin=-20, vmax=30, colors='k', linestyles="-")
    ax.clabel(contours, fmt='%d', colors='k')
    contours_1 = ax.contour(T_grid_1, P_grid_1 / 1e9, np.log10(Y_grid), (-14, -12, -10, -8, -6, -4), colors='c', linestyles="-")
    ax.clabel(contours_1, fmt='%d', colors='c')
    ax.plot(T_values_1, P_eq_values_1/1e9, "-.")  # plot the equilibrium phase boundary

    ax.grid()

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax.invert_yaxis()

    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Pressure (GPa)")
    ax.set_title(r"$I_v$ and $Y$")

    # Plot ts, tg.
    # Then categorize the diagrame based on these values
    # Also append the controus of Av
    ax = fig.add_subplot(gs[0, 1])
    contours = ax.contour(T_grid_1, P_grid_1 / 1e9, np.log10(ts_grid/year), (2, 4, 6, 8), colors='k', linestyles="-")
    ax.clabel(contours, fmt='%d', colors='k')
    contours = ax.contour(T_grid_1, P_grid_1 / 1e9, np.log10(tg_grid/year), (2, 4, 6, 8), colors='c', linestyles="-")
    ax.clabel(contours, fmt='%d', colors='c')
    ax.plot(T_values_1, P_eq_values_1/1e9, "-.")  # plot the equilibrium phase boundary

    category_grid = categorize_the_diagram(ts_grid, tg_grid)
    norm = mcolors.BoundaryNorm(boundaries=[-0.5, 0.5, 1.5, 2.5, 3.5], ncolors=4)
    cmap = plt.get_cmap('Pastel1', 4)
    cmesh = ax.pcolormesh(T_grid_1, P_grid_1 / 1e9, category_grid, cmap=cmap, norm=norm, shading='auto')
    # cbar = plt.colorbar(cmesh, ax=ax, ticks=[0, 1, 2, 3])

    contours = ax.contour(T_grid_1, P_grid_1 / 1e9, np.log10(Av_grid), (0), colors='k', linestyles="--")
    # ax.clabel(contours, fmt='%d', colors='k')
    contours = ax.contour(T_grid_1, P_grid_1 / 1e9, np.log10(Av_c_grid), (0), colors='tab:gray', linestyles="--")
    # ax.clabel(contours, fmt='%d', colors='c')

    ax.grid()

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax.invert_yaxis()

    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Pressure (GPa)")
    ax.set_title(r"$t_s$ and $t_g$")

    # fig.colorbar(h4, ax=ax, label="ts")

    # Plot Av
    ax = fig.add_subplot(gs[1, 0])
    contours = ax.contour(T_grid_1, P_grid_1 / 1e9, np.log10(Av_grid), (-4, 0, 4), colors='k', linestyles="-")
    ax.clabel(contours, fmt='%d', colors='k')
    contours = ax.contour(T_grid_1, P_grid_1 / 1e9, np.log10(Av_c_grid), (-4, 0, 4), colors='c', linestyles="-")
    ax.clabel(contours, fmt='%d', colors='c')
    ax.plot(T_values_1, P_eq_values_1/1e9, "-.")  # plot the equilibrium phase boundary

    ax.grid()

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax.invert_yaxis()

    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Pressure (GPa)")
    ax.set_title("Av")
    # fig.colorbar(h3, ax=ax, label="Av")

    fig_path = os.path.join(results_dir, "Mo_kinetics_diagram_nu_%d_Coh_%.2f_d_%.2f.pdf" % (Mo_Kinetics.Kinetics.nucleation_type, Coh, d0))
    fig.savefig(fig_path)

    print("Save figure: %s" % fig_path)

    # Reset rcParams to defaults
    rcdefaults()

2. Inspect a given time

Then, we plot at a given time

In [None]:
if is_kinetics_diagram:

    from scipy.ndimage import zoom
    from cmcrameri import cm as ccm 

    # Plot, at one fixed time t_constant, the transformed volume V, the
    # quantity N and the average grain size D_av on the T-P plane.
    # Uses names defined in earlier cells: interpolators, T_flat_1, P_flat_1,
    # T_grid_1, P_grid_1, T0, P0, t0, T_values_1, P_eq_values_1, Mo_Kinetics,
    # Coh, d0, results_dir, MultipleLocator and rcdefaults.

    # Rules of thumb:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (600.0, 1800.0)
    x_tick_interval = 200.0   # tick interval along x
    y_lim = (10.0, 20.0)
    y_tick_interval = 5.0  # tick interval along y
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # constant time in seconds (1 Myr: snapshot used for the supplementary material figure)
    t_constant = 1e6*year

    # Scaled time coordinate shared by every query point of the regular T-P mesh
    t_query = np.full(T_flat_1.shape, t_constant/t0)

    # interpolate the transformed volume V
    V_flat_1 = interpolators["V"](T_flat_1/T0, P_flat_1/P0, t_query)
    V_grid = V_flat_1.reshape(T_grid_1.shape)

    # interpolate N
    N_flat_1 = interpolators["N"](T_flat_1/T0, P_flat_1/P0, t_query)
    N_grid = N_flat_1.reshape(T_grid_1.shape)

    # log10(N). N may contain zeros, so silence the log-of-zero warning.
    # Entries with log10(N) < 10 are set to NaN and are left blank in the plot.
    with np.errstate(divide="ignore", invalid="ignore"):
        N_log_grid = np.log10(N_grid)
    N_log_grid[N_log_grid < 10.0] = np.nan

    # Average grain size, computed only where N > 100 (NaN elsewhere):
    # D_av = (6 V / (pi N))^(1/3), the diameter of a sphere of volume V / N.
    D_av_grid = np.full(N_grid.shape, np.nan)
    mask = (N_grid > 100)
    D_av_grid[mask] = ((V_grid[mask] / N_grid[mask]) * 6.0 / np.pi)**(1/3.0)
    with np.errstate(divide="ignore", invalid="ignore"):
        D_av_log_grid = np.log10(D_av_grid)

    # Create subplots (three of the four panels are used)
    fig, axes = plt.subplots(2, 2, figsize=(14*scaling_factor, 12*scaling_factor), constrained_layout=True)

    # V, with iso-lines at V = 0.5 and V = 0.99
    h_V = axes[0, 0].pcolormesh(T_grid_1, P_grid_1 / 1e9, V_grid, cmap="viridis", shading="auto")
    axes[0, 0].contour(T_grid_1, P_grid_1 / 1e9, V_grid, (0.5, 0.99), colors=['tab:gray', 'k'], linestyles="-")

    # log10(N)
    h_N = axes[0, 1].pcolormesh(T_grid_1, P_grid_1 / 1e9, N_log_grid, cmap=ccm.batlow, shading="auto", vmin=0, vmax=30)

    # log10(D_av)
    h_D = axes[1, 0].pcolormesh(T_grid_1, P_grid_1 / 1e9, D_av_log_grid, cmap=ccm.tokyo, shading="auto", vmin=-10, vmax=-1)

    # Formatting shared by all panels. The colorbar labels state the log10
    # scaling that is actually plotted for N and D_av.
    panels = (
        (axes[0, 0], h_V, "V", "V"),
        (axes[0, 1], h_N, "N", r"$\log_{10}(N)$"),
        (axes[1, 0], h_D, "D_av", r"$\log_{10}(D_{av})$"),
    )
    for ax, mappable, panel_title, cbar_label in panels:
        ax.plot(T_values_1, P_eq_values_1/1e9, "-.")  # plot the equilibrium phase boundary

        ax.grid()

        ax.set_xlim(x_lim)
        ax.set_ylim(y_lim)

        ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
        ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
        ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
        ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

        ax.invert_yaxis()  # pressure increases downward

        ax.set_xlabel("Temperature (K)")
        ax.set_ylabel("Pressure (GPa)")
        ax.set_title(panel_title)
        fig.colorbar(mappable, ax=ax, label=cbar_label)

    # The fourth panel is unused; hide it so no empty frame appears in the figure
    axes[1, 1].set_axis_off()

    # save figure (before showing it, so the saved file does not depend on the backend's show behavior)
    fig_path = os.path.join(results_dir, "PTV_nu_%d_Coh_%.2e_d_%.2e_t%.2f.pdf" % (Mo_Kinetics.Kinetics.nucleation_type, Coh, d0, t_constant/1e6/year))
    fig.savefig(fig_path)

    print("Saved figure %s" % fig_path)

    plt.show()

    # Reset rcParams to defaults
    rcdefaults()

3. Plot contours of the transformed volume at different time steps

In [None]:
if is_kinetics_diagram:

    from scipy.ndimage import zoom
    from matplotlib.lines import Line2D

    # Overlay on one T-P diagram the iso-contour V = V_contour_level of the
    # transformed volume at several times, with one colour per time.
    # Uses names defined in earlier cells: interpolators, T_flat_1, P_flat_1,
    # T_grid_1, P_grid_1, T0, P0, t0, T_values_1, P_eq_values_1, Mo_Kinetics,
    # Coh, d0, results_dir, MultipleLocator and rcdefaults.

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Rules of thumb:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (400.0, 1800.0)
    x_tick_interval = 200.0   # tick interval along x
    y_lim = (10.0, 30.0)
    y_tick_interval = 5.0  # tick interval along y
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # Level of V drawn at every time. 0.5 (half transformed) is an assumed
    # choice, matching one of the levels in the single-time plot; adjust as needed.
    V_contour_level = 0.5

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # times (s) at which V is contoured
    ts = np.array([5e4, 1e5, 1e6]) * year

    # Create subplots
    fig, ax = plt.subplots(figsize=(7*scaling_factor, 6*scaling_factor), constrained_layout=True)

    ax.plot(T_values_1, P_eq_values_1/1e9, "-.")  # plot the equilibrium phase boundary

    # Initialize a list for legend entries
    legend_lines = []
    legend_labels = []

    for i, t_constant in enumerate(ts):

        # interpolate V on the regular T-P mesh at this time
        V_flat_1 = interpolators["V"](T_flat_1/T0, P_flat_1/P0, np.full(T_flat_1.shape, t_constant/t0))

        # reshape to V_grid
        V_grid = V_flat_1.reshape(T_grid_1.shape)

        # Colour 0 is used by the equilibrium boundary. Wrap around if there
        # are more times than colours in the cycle.
        color = default_colors[(i + 1) % len(default_colors)]

        # Draw the contour for this time; the legend entry below only describes it
        ax.contour(T_grid_1, P_grid_1 / 1e9, V_grid, (V_contour_level,), colors=[color], linestyles="--")

        # Add entry to legend list
        legend_lines.append(Line2D([0], [0], color=color, linestyle="--"))
        legend_labels.append("%.1e year" % (t_constant/year))

    ax.legend(legend_lines, legend_labels)

    ax.grid()

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax.invert_yaxis()  # pressure increases downward

    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Pressure (GPa)")
    # Contours are drawn on the regular grid (no upsampling is done in this cell)
    ax.set_title("V = %.2f contours at different times" % V_contour_level)

    # Adjust spine thickness for this plot
    for spine in ax.spines.values():
        spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # save figure (before showing it)
    fig_path = os.path.join(results_dir, "PTV_contours.png")
    fig.savefig(fig_path)

    print("Saved figure %s" % fig_path)

    fig_path_pdf = os.path.join(results_dir, "PTV_contours_nu_%d_Coh_%.2e_d_%.2e.pdf" % (Mo_Kinetics.Kinetics.nucleation_type, Coh, d0))
    fig.savefig(fig_path_pdf)

    print("Saved figure %s" % fig_path_pdf)

    plt.show()

    # Reset rcParams to defaults
    rcdefaults()

# Inspect the results from cpp code

(See OneNote "My cpp script"; append the technical requirements here.)

To run the cpp code

Navigate to HaMaGeoLib/cpp

Create a build directory and enter it (`mkdir build && cd build`), then configure:

    cmake ..

Run and test

Run and create a diagram

	  make metastable_diagram

    ./metastable_diagram

### Plot the diagram

In [None]:
plot_cpp_diagram_results = False

if plot_cpp_diagram_results:

    from scipy.interpolate import UnivariateSpline
    from matplotlib import pyplot as plt

    # The notebook header puts the repository root (root_path) on sys.path and
    # imports the package as `hamageolib`. Reuse that layout rather than
    # recomputing root_path from a different number of parent directories and
    # importing `utils.*` from inside the package. That approach does not match
    # the header's directory layout and would redirect results_dir.
    from hamageolib.utils.exception_handler import my_assert
    import hamageolib.utils.plot_helper as plot_helper

    base_dir = Path().resolve()

    # Dump results into the same dtemp directory as the rest of the notebook
    results_dir = os.path.join(root_path, "dtemp")
    os.makedirs(results_dir, exist_ok=True)

First we load the data file generated by the cpp script

In [None]:
if plot_cpp_diagram_results:

    # Output table of the cpp `metastable_diagram` program. It is assumed to be
    # written into <repository root>/dtemp, i.e. results_dir. TODO confirm
    # against the cpp code if the file is not found.
    file_path = os.path.join(results_dir, "metastable_diagram_cpp.txt")

    # Raise (rather than assert) so the check survives `python -O` and names the file
    if not os.path.isfile(file_path):
        raise FileNotFoundError("cpp diagram output not found: %s (run ./metastable_diagram first)" % file_path)

    # Comma-separated table with one header row. Columns used later:
    # 0 -> P, 1 -> T, 2 -> t, 7 -> the value interpolated as V.
    data_in = np.loadtxt(file_path, delimiter=',', dtype=float, skiprows=1)

Then, remake the plot for transition volume

In [None]:
if plot_cpp_diagram_results:

    from scipy.interpolate import NearestNDInterpolator

    # Nondimensionalization scales shared by the interpolator and its queries
    P0, T0, t0 = 1e9, 100.0, 10000 * year  # Pa, K, s

    # Regular mesh on the T-P plane: 100 temperatures x 200 pressures
    T_values_1 = np.linspace(273.15, 1873.15, 100)  # Temperature in Kelvin
    P_values_1 = np.linspace(0.0, 30e9, 200)  # Pressure in Pascals
    T_grid_1, P_grid_1 = np.meshgrid(T_values_1, P_values_1)
    T_flat_1, P_flat_1 = T_grid_1.flatten(), P_grid_1.flatten()

    # Equilibrium pressure of the phase boundary (PT410 comes from an earlier cell)
    P_eq_values_1 = Meta.compute_eq_P(PT410, T_values_1)

    # Nearest-neighbour lookup of column 7 (used as V) in scaled (T, P, t) space.
    # cpp table columns: 0 -> P, 1 -> T, 2 -> t.
    scaled_TPt = np.stack(
        (data_in[:, 1] / T0, data_in[:, 0] / P0, data_in[:, 2] / t0),
        axis=1,
    )
    interpolator = NearestNDInterpolator(scaled_TPt, data_in[:, 7])

Plot at a given time

In [None]:
if plot_cpp_diagram_results:

    from scipy.ndimage import zoom
    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator

    # Compare, at one time, the raw nearest-neighbour V field on the T-P mesh
    # with a cubic-spline upsampled copy of it.

    # Rules of thumb:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (273.15, 1773.15)
    x_tick_interval = 200.0   # tick interval along x
    y_lim = (0.0, 30.0)
    y_tick_interval = 5.0  # tick interval along y
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Fonts compatible with publishing tools such as Illustrator
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # time of the snapshot, in seconds
    t_constant = 5e4*year

    # Nearest-neighbour lookup on the regular mesh in scaled coordinates.
    # (idw_interpolation from the utility section is an inverse-distance alternative.)
    V_flat_1 = interpolator(T_flat_1/T0, P_flat_1/P0, np.full(T_flat_1.shape, t_constant/t0))
    V_grid = V_flat_1.reshape(T_grid_1.shape)

    # Upsample with a cubic-spline zoom; the fine mesh spans the same T-P range
    zoom_factor = 2
    V_smooth_grid = zoom(V_grid, zoom_factor, order=3)
    n_P_fine, n_T_fine = V_smooth_grid.shape
    T_fine = np.linspace(np.min(T_values_1), np.max(T_values_1), n_T_fine)
    P_fine = np.linspace(np.min(P_values_1), np.max(P_values_1), n_P_fine)
    T_fine_grid, P_fine_grid = np.meshgrid(T_fine, P_fine)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6), constrained_layout=True)

    # Left panel: raw grid, iso-lines at V = 0.4 and 0.8
    h1 = axes[0].pcolormesh(T_grid_1, P_grid_1 / 1e9, V_grid, cmap="viridis", shading="auto")
    axes[0].contour(T_grid_1, P_grid_1 / 1e9, V_grid, (0.4, 0.8), colors=['tab:gray', 'k'], linestyles="--")

    # Right panel: upsampled grid, iso-lines at 1 - exp(-1) and 1 - exp(-2)
    h2 = axes[1].pcolormesh(T_fine_grid, P_fine_grid / 1e9, V_smooth_grid, cmap="viridis", shading="auto", vmin=0.0, vmax=1.0)
    axes[1].contour(T_fine_grid, P_fine_grid / 1e9, V_smooth_grid,
                    (0.6321, 0.864),
                    colors=['tab:gray', 'k'], linestyles="--")

    # Formatting shared by both panels
    panel_specs = (
        (axes[0], h1, "Original Coarse Grid"),
        (axes[1], h2, "Smoothed Grid (Upsampled)"),
    )
    for ax, mappable, panel_title in panel_specs:
        ax.plot(T_values_1, P_eq_values_1/1e9, "-.")  # equilibrium phase boundary
        ax.grid()
        ax.set_xlim(x_lim)
        ax.set_ylim(y_lim)
        ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
        ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
        ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
        ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))
        ax.set_xlabel("Temperature (K)")
        ax.set_ylabel("Pressure (GPa)")
        ax.set_title(panel_title)
        fig.colorbar(mappable, ax=ax, label="V")
        for spine in ax.spines.values():
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    plt.show()

    # Write the figure as PNG, named after the snapshot time in Myr
    fig_path = os.path.join(results_dir, "PTV_cpp_t%.2f.png" % (t_constant/1e6/year))
    fig.savefig(fig_path)

    print("Saved figure %s" % fig_path)

    # PDF output (disabled):
    # fig_path_pdf = os.path.join(results_dir, "PTV_cpp_t%.2f.pdf" % (t_constant/1e6/year))
    # fig.savefig(fig_path_pdf)
    # print("Saved figure %s" % fig_path_pdf)

    # Restore matplotlib defaults
    rcdefaults()

# Compute analytic MOW content from ASPECT P, T conditions

These blocks follow these steps:
* Run a test case
* Export the dataset from the pvtu files
* Apply a nearest-neighbor interpolation
* Plot the colormap and the contours

## Run aspect

Key Steps

- **Input Verification**:
  - Ensures the template directory exists and contains the parameter file (`case.prm`) and the World Builder file (`case.wb`).
  - Creates `case_dir` if needed and writes the modified `case.prm` and `case.wb` into it for the ASPECT run.

- **Execution**:
  - Runs the ASPECT executable using `subprocess.run`.
  - Captures both `stdout` (standard output) and `stderr` (standard error) for validation.

- **Output Validation**:
  - Verifies that the expected wallclock time summary appears in the output log.
  - Ensures that the error stream (`stderr`) is empty, confirming a successful execution.

In [None]:
is_run_aspect = False

if is_run_aspect:
    import re
    import json
    from shutil import copy
    # The package is imported as `hamageolib` in the notebook header. A bare
    # `utils.` prefix only resolves if some cell has put the package directory
    # itself on sys.path.
    from hamageolib.utils.dealii_param_parser import parse_parameters_to_dict, save_parameters_from_dict
    from hamageolib.utils.world_builder_file_parser import find_feature_by_name, update_or_add_feature

    # Define paths to the ASPECT executable and the case directory
    aspect_executable = "/home/lochy/Softwares/aspect/build_master_TwoD/aspect"
    group_dir = "/mnt/lochz/ASPECT_DATA/TwoDSubduction/MO_kinetics_test"
    template_dir = os.path.join(group_dir, "test_case_template")  # Directory containing templates
    # case_dir = "/mnt/lochz/ASPECT_DATA/TwoDSubduction/MO_kinetics_test/test_case_ini"  # Initial case, larger geometry

    # Validate the template inputs. Raise (rather than assert) so the checks
    # survive `python -O` and report which path is missing.
    if not os.path.isdir(template_dir):
        raise NotADirectoryError("Template directory not found: %s" % template_dir)

    prm_template_path = os.path.join(template_dir, "case.prm")
    wb_template_path = os.path.join(template_dir, "case.wb")
    for template_path in (prm_template_path, wb_template_path):
        if not os.path.isfile(template_path):
            raise FileNotFoundError("Template file not found: %s" % template_path)

    # Case setups
    wb_sp_velocity = 0.08 # m/yr
    wb_sp_age = 140e6 # yr
    wb_trench_x = 200e3 # m
    # Ridge placed age * velocity away from the trench
    wb_ridge_x = wb_trench_x - wb_sp_age * wb_sp_velocity

    # Per-case directory, named after plate age (Myr) and velocity (cm/yr)
    case_dir = os.path.join(group_dir, "test_case_sp%.1f_v%.1e" % (wb_sp_age/1e6, wb_sp_velocity*100))
    os.makedirs(case_dir, exist_ok=True)

    # Modify the template
    # Also read important parameters like the size of the model
    with open(prm_template_path, 'r') as file:
        params_dict = parse_parameters_to_dict(file)

    params_dict["Output directory"] = os.path.join(case_dir, "output")
    params_dict["World builder file"] = os.path.join(case_dir, "case.wb")

    x_extent = float(params_dict["Geometry model"]["Box"]["X extent"])
    y_extent = float(params_dict["Geometry model"]["Box"]["Y extent"])

    with open(wb_template_path, 'r') as fin:
        wb_dict = json.load(fin)

    slab_dict = find_feature_by_name(wb_dict, "Slab") # Extract the "Slab" feature from the World Builder data
    sp_dict = find_feature_by_name(wb_dict, "Subducting plate")

    # Trench and ridge are both lines of constant x spanning y = -1000 .. 1000.
    # The ridge end points need distinct y values; identical points would give
    # a zero-length ridge segment.
    slab_dict["coordinates"] = [[wb_trench_x, -1000.0], [wb_trench_x, 1000.0]]
    slab_dict["temperature models"][0]["plate velocity"] = wb_sp_velocity
    slab_dict["temperature models"][0]["ridge coordinates"][0] = [[wb_ridge_x, -1000.0], [wb_ridge_x, 1000.0]]

    sp_dict["temperature models"][0]["plate age"] = wb_sp_age

    wb_dict = update_or_add_feature(wb_dict, "Slab", slab_dict)
    wb_dict = update_or_add_feature(wb_dict, "Subducting plate", sp_dict)

    # Write the parameter file and world builder file into the case directory
    prm_path = os.path.join(case_dir, "case.prm")
    wb_path = os.path.join(case_dir, "case.wb")

    with open(prm_path, 'w') as output_file:
        save_parameters_from_dict(output_file, params_dict)

    with open(wb_path, 'w') as fout:
        json.dump(wb_dict, fout)

    for written_path in (prm_path, wb_path):
        if not os.path.isfile(written_path):
            raise FileNotFoundError("Failed to write case file: %s" % written_path)

    # Run the ASPECT executable with the parameter file.
    # 'capture_output=True' collects both stdout and stderr for the checks below.
    completed_process = subprocess.run([aspect_executable, prm_path], capture_output=True, text=True)

    # Capture the standard output and error streams
    stdout = completed_process.stdout
    stderr = completed_process.stderr

    # Uncomment the following lines for debugging purposes to inspect the output
    # print(stdout)  # Debugging: Prints the standard output
    # print(stderr)  # Debugging: Prints the standard error

    # A successful run exits with code 0, writes nothing to stderr, and prints
    # a line similar to:
    # -- Total wallclock time elapsed including restarts: 1s
    # Search the whole log instead of a fixed line offset, which breaks when
    # ASPECT prints a different number of trailing lines.
    if completed_process.returncode != 0:
        raise RuntimeError("ASPECT exited with code %d:\n%s" % (completed_process.returncode, stderr))
    if re.search(r"Total wallclock", stdout) is None:
        raise RuntimeError("ASPECT output has no wallclock summary; the run may not have finished.")
    if stderr != "":
        raise RuntimeError("ASPECT wrote to stderr:\n%s" % stderr)

## Export data and plot

This script reads simulation data from a `.pvtu` file, processes it using VTK and NumPy, and sets up interpolators for various physical fields such as temperature, pressure, and resolution.

In [None]:
if is_run_aspect:

    import vtk
    from vtk.util.numpy_support import vtk_to_numpy
    # Same package layout as the notebook header (`hamageolib.utils`)
    from hamageolib.utils.vtk_utilities import calculate_resolution
    import time
    import numpy as np
    from scipy.interpolate import NearestNDInterpolator

    # Define the input file path and field names to extract
    pvtu_file = os.path.join(case_dir, "output", "solution", "solution-00005.pvtu")
    field_names = ["T", "p"]  # Field names to extract: temperature (T) and pressure (p)

    # Fail early with a clear message instead of reading an empty grid
    if not os.path.isfile(pvtu_file):
        raise FileNotFoundError("Solution file not found: %s" % pvtu_file)

    start = time.time()

    # Read the pvtu file
    reader = vtk.vtkXMLPUnstructuredGridReader()
    reader.SetFileName(pvtu_file)
    reader.Update()

    # Get the output data from the reader
    grid = reader.GetOutput()  # Access the unstructured grid
    data_set = reader.GetOutputAsDataSet()  # Access the dataset representation
    points = grid.GetPoints()  # Extract the points (coordinates)
    cells = grid.GetCells()  # Extract the cell connectivity information
    point_data = data_set.GetPointData()  # Access point-wise data

    end = time.time()
    print("Reading files takes %.2f s" % (end - start))
    start = end

    # Resolution of the grid, from hamageolib.utils.vtk_utilities
    resolutions = calculate_resolution(grid)

    end = time.time()
    print("Calculating resolution takes %.2f s" % (end - start))
    start = end

    # Construct a vtkPolyData object to hold points and cell information
    i_poly_data = vtk.vtkPolyData()
    i_poly_data.SetPoints(points)  # Add points to the PolyData object
    i_poly_data.SetPolys(cells)  # Add cell connectivity to the PolyData object

    # Add point data fields to the vtkPolyData. Every field is required below,
    # so a missing one is an error rather than a warning.
    for idx, field_name in enumerate(field_names):
        array = point_data.GetArray(field_name)  # Retrieve the field array
        if not array:
            raise KeyError("Field %s not found in %s" % (field_name, pvtu_file))
        if idx == 0:  # The first field becomes Scalars
            i_poly_data.GetPointData().SetScalars(array)
        else:  # Additional fields are added as arrays
            i_poly_data.GetPointData().AddArray(array)

    # Validate that points were successfully added to the vtkPolyData object
    noP = i_poly_data.GetNumberOfPoints()
    if noP == 0:
        raise ValueError("No points were added to i_poly_data!")

    end = time.time()
    print("Constructing polydata takes %.2f s" % (end - start))
    start = end

    # Export data to NumPy arrays for easier processing
    points_np = vtk_to_numpy(i_poly_data.GetPoints().GetData())  # Convert points to NumPy
    Ts = vtk_to_numpy(i_poly_data.GetPointData().GetArray("T"))  # Temperature array
    Ps = vtk_to_numpy(i_poly_data.GetPointData().GetArray("p"))  # Pressure array

    end = time.time()
    print("Exporting data to NumPy arrays takes %.2f s" % (end - start))
    start = end

    # Extract 2D coordinates (x, y) from the points
    points_2d = points_np[:, :2]  # Use only the first two columns for 2D coordinates

    # Create interpolators for temperature, pressure, and resolution
    interpolator = NearestNDInterpolator(points_2d, Ts)  # Interpolator for temperature
    interpolator_P = NearestNDInterpolator(points_2d, Ps)  # Interpolator for pressure
    interpolator_r = NearestNDInterpolator(points_2d, resolutions)  # Interpolator for resolution

Interpolate to a regular grid

This block interpolates simulation data (e.g., temperature, resolution) onto a regular 2D grid and visualizes it using various plots, including colormaps and contour plots.

In [None]:
if is_run_aspect:

    # Interpolate the nearest-neighbour fields (interpolator, interpolator_P,
    # interpolator_r from the previous cell) onto a regular 2D grid spanning
    # the bounding box of points_2d.

    import time
    import numpy as np
    from scipy.interpolate import NearestNDInterpolator

    # Time this cell on its own rather than from a timestamp left by another cell
    start = time.time()

    # Define the interval for the grid (in meters)
    interval = 5e3  # 5 km grid interval

    # Determine the bounding box of the 2D points
    x_min, y_min = np.min(points_2d, axis=0)
    x_max, y_max = np.max(points_2d, axis=0)

    # Define a regular grid within the bounding box (upper bound excluded by np.arange)
    x_grid = np.arange(x_min, x_max, interval)
    y_grid = np.arange(y_min, y_max, interval)
    xv, yv = np.meshgrid(x_grid, y_grid, indexing="ij")  # xv, yv have shape (len(x_grid), len(y_grid))

    # Flatten the grid for interpolation
    grid_points_2d = np.vstack([xv.ravel(), yv.ravel()]).T

    # Interpolate temperature (T) values onto the regular grid
    T_grid = interpolator(grid_points_2d)
    T_grid = T_grid.reshape(xv.shape)  # Reshape back to the grid

    # Interpolate pressure (P) values onto the regular grid
    P_grid = interpolator_P(grid_points_2d)
    P_grid = P_grid.reshape(xv.shape)  # Reshape back to the grid

    # Interpolate resolutions onto the regular grid
    resolutions_grid = interpolator_r(grid_points_2d)
    resolutions_grid = resolutions_grid.reshape(xv.shape)

    end = time.time()
    print("Interpolating to regular grid takes %.2f s" % (end - start))
    start = end

## Read WorldBuilder information

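- **Load World Builder data**: reads the "Slab" feature from the World Builder configuration (`wb_dict`, loaded from the `case.wb` JSON file) and extracts the slab segments, trench location, and subduction velocity.

- **Segment analysis**: computes segment lengths, depths, and dip angles. These are later used to determine distances to the slab curve for visualization.

- **Visualization**: the results are plotted in the cells below.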

In [None]:
if is_run_aspect:

    import json

    # Pull the slab geometry and kinematics out of the World Builder dict
    # (wb_dict is loaded and updated in the "Run aspect" cell).
    slab_dict = find_feature_by_name(wb_dict, "Slab")
    slab_temperature_model = slab_dict["temperature models"][0]

    slab_segments = slab_dict["segments"]
    trench_x = slab_dict["coordinates"][0][0]  # x of the first trench point

    plate_velocity = slab_temperature_model["plate velocity"]  # m/yr
    subduct_velocity_cm_yr = 100.0 * plate_velocity  # cm/yr
    subduct_velocity = plate_velocity / year  # m/s

    # Sample the slab segments (process_segments is defined in an earlier cell)
    lengths, depths, dip_angles, Xs = process_segments(slab_segments, n_spacing=200)

    # Subduction time for each sample: length / velocity (s)
    sub_ts = lengths / subduct_velocity

## Processing the data grid

Key Steps

1. **Depth Grid Calculation**:
   - Converts $y$-coordinates to depths ($\text{depth}_{\text{grid}} = y_{\text{max}} - y_v$).
   - Interpolates segment lengths along the depth grid.

2. **Slab Internal Masking**:
   - Defines the slab's internal region based on distances from the slab surface (-5 km to 100 km).

3. **Subduction Time Calculation**:
   - Computes subduction time ($t_{\text{sub}}$) for points inside the slab using:
     $$ t_{\text{sub}} = \frac{\text{length}}{\text{subduction velocity}} $$

4. **Contour Generation**:
   - Generates curves at specified distances from the slab surface (done in the next cell; currently a single curve at 40 km).

Constants and Assumptions

- **Slab Internal Mask**:
  - Distance range for slab internals: $-5 \, \text{km} \leq \text{distance} \leq 100 \, \text{km}$.
- **Contour Distance**:
  - A single contour at $40 \, \text{km}$ from the slab surface (set by `contour_distance`).

In [None]:
if is_run_aspect:

    # Build the slab-interior region on the regular grid and assign each
    # interior node a subduction time t_sub = length / subduction velocity.

    # 1. Distance of every grid node to the slab surface curve
    slab_surface_x = Xs + trench_x
    slab_surface_y = y_max - depths
    distance_v = distances_to_curve(slab_surface_x, slab_surface_y, xv.ravel(), yv.ravel())
    distance_grid = distance_v.reshape(xv.shape)

    # 2. Depth below the top of the grid's bounding box
    depth_grid = y_max - yv
    depth_grid_flat = depth_grid.ravel()

    # 3. Slab length at each node's depth (linear interpolation in depths -> lengths)
    length_grid_flat = np.interp(depth_grid_flat, depths, lengths)
    length_grid = length_grid_flat.reshape(depth_grid.shape)

    # 4. Slab interior: distance between -5 km and 100 km from the surface
    mask_slab1 = np.logical_and(distance_grid >= -5e3, distance_grid <= 100e3)
    xv_slab, yv_slab = xv[mask_slab1], yv[mask_slab1]

    # 5. t_sub inside the slab, +inf outside it
    t_sub_grid = np.where(mask_slab1, length_grid / subduct_velocity, np.inf)

Diagnose a profile at a fixed distance from the slab surface

In [None]:
if is_run_aspect:

    # Sample T and P along a curve at a fixed distance from the slab surface.
    # Compare P with the linear equilibrium boundary defined by PT410, and
    # write (P, T, t_sub) along the curve to a text file.

    # Distance of the sampled curve from the slab surface (m)
    contour_distance = [40e3]  # 40 km

    # Offset curve, built in the same frame as distance_grid (y = y_max - depth)
    foo_contour_Xs, foo_contour_Ys = offset_curve(Xs + trench_x, y_max - depths, contour_distance)

    # Convert back to depth with the same reference (y_max) used to build the curve
    foo_contour_depths = y_max - foo_contour_Ys

    foo_grid_points_2d = np.vstack([foo_contour_Xs.ravel(), foo_contour_Ys.ravel()]).T

    foo_contour_Ts = interpolator(foo_grid_points_2d)
    foo_contour_Ps = interpolator_P(foo_grid_points_2d)

    # Linear equilibrium boundary through (PT410["T"], PT410["P"]) with slope
    # PT410["cl"] (presumably the Clapeyron slope), solved for P and for T
    foo_contour_Ps_eq = (foo_contour_Ts - PT410["T"]) * PT410["cl"] + PT410["P"]
    foo_contour_Ts_eq = (foo_contour_Ps - PT410["P"]) / PT410["cl"] + PT410["T"]
    foo_mask_eq = (foo_contour_Ps > foo_contour_Ps_eq)  # True where P exceeds the equilibrium pressure

    foo_contour_lengths = np.interp(foo_contour_depths, depths, lengths)  # Interpolate lengths at depth values
    foo_contour_ts = foo_contour_lengths / subduct_velocity

    o_file = os.path.join(results_dir, "foo_contour_data.txt")

    # Flatten every column so the table is well formed even if offset_curve
    # returns 2-D arrays (one row per contour distance)
    odata = np.column_stack((np.ravel(foo_contour_Ps), np.ravel(foo_contour_Ts), np.ravel(foo_contour_ts)))

    # Save to text file with headers
    np.savetxt(o_file, odata, header="Ps Ts ts", fmt="%.6e")

    print("Saved file %s" % o_file)


1. Plot the initial conditions

A. Temperature, subduction time, and mesh resolution

(Editing in AI)
Append the color bar of the subduction time from the plot at gs[1, 0];

add velocity vectors inside the slab.

B. Profile properties

(Editing in AI)
Combine the results on the MO extent in the following block;
append the annotation of the subduction time to the Y axis, using the printed outputs.

In [None]:
if is_run_aspect:

    # Initialize plots
    # Figure of the initial conditions on a 2x2 grid:
    #   gs[0, 0]: mesh resolution, subduction time inside the slab, T contours
    #   gs[0, 1]: T and P along the diagnostic profile (foo_contour_*)
    #   gs[1, 0]: subduction time inside the slab, with its color bar
    #   gs[1, 1]: slab length and dip angle versus depth

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    from matplotlib import gridspec

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Example usage
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (0.0, 1000.0)
    x_tick_interval = 250   # tick interval along x
    y_lim = (0.0, 1000.0)
    y_tick_interval = 250  # tick interval along y
    v_lim = (0.0, 20000)  # range of the resolution color map
    v_lim3 = (0.0, 20.0)  # range of the subduction-time color map (Ma)
    v_level = 50  # number of levels in contourf plot
    v_tick_interval = 5000.0  # tick interval of the resolution color bar
    v_tick_interval3 = 5.0  # tick interval of the subduction-time color bar
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # Create a figure with a 2x2 grid layout
    fig = plt.figure(figsize=(12, 10), tight_layout=True)
    gs = gridspec.GridSpec(2, 2)

    # First plot: color map of the mesh resolution
    # Then a plot of sub_t inside the slab (drawn on top)
    # Add contour lines of T to the color map
    # (Quiver plot of the assumed subducting velocity is kept below for later)
    levels = np.linspace(v_lim[0], v_lim[1], v_level)
    ticks=np.arange(v_lim[0], v_lim[1], v_tick_interval)
    ax = fig.add_subplot(gs[0, 0])
    color_map = ax.contourf(xv/1e3, yv/1e3, resolutions_grid,  vmin=v_lim[0], vmax=v_lim[1], levels=levels, cmap="plasma_r")  # Resolution colormap
    cbar = fig.colorbar(color_map, ax=ax, label="Resolution")  # Add colorbar
    cbar.set_ticks(ticks)

    # t_sub_grid is inf outside the slab, so only the slab interior is filled here
    levels3 = np.linspace(v_lim3[0], v_lim3[1], v_level)
    ticks3=np.arange(v_lim3[0], v_lim3[1], v_tick_interval3)
    color_map3 = ax.contourf(xv/1e3, yv/1e3, t_sub_grid/year/1e6, vmin=v_lim3[0], vmax=v_lim3[1], levels=levels3, cmap="viridis") # sub_t
    # cbar3 = fig.colorbar(color_map3, ax=ax, label="time from trench (Ma)")
    # cbar3.set_ticks(ticks3)

    contours = ax.contour(
        xv/1e3, yv/1e3, T_grid-273.15, levels=np.arange(100.0, 1473.15 + 100.0, 200.0), colors="black", linewidths=0.5
    )
    ax.clabel(contours, inline=True, fontsize=8, fmt="%.1f")  # Add labels to the contours

    # Quiver plot: maybe append at a last step of editing
    # dip_angle_in_slab = np.interp(y_extent-yv_slab, depths, dip_angles)
    # vx = subduct_velocity_cm_yr * np.cos(dip_angle_in_slab)
    # vy = -subduct_velocity_cm_yr * np.sin(dip_angle_in_slab)
    # skip = (slice(None, None, 40)) # plot with interval by skipping
    # Q_map = ax.quiver(xv_slab[skip]/1e3, yv_slab[skip]/1e3, vx[skip], vy[skip], angles='xy', color="black", scale=400.0*scaling_factor)

    ax.set_aspect("equal", adjustable="box")  # Equal aspect ratio

    ax.plot((Xs + trench_x)/1e3, (y_max - depths)/1e3, ".", markersize=2*scaling_factor, color=default_colors[0]) # slab surface
    # Grid points about 100 km from the slab surface (inner limit of the slab region)
    mask_l = np.abs((distance_grid - 100e3)) < 2e3
    ax.plot(xv[mask_l]/1e3, yv[mask_l]/1e3, ".", markersize=2*scaling_factor, color=default_colors[2]) # 100 km from slab surface
    ax.plot(foo_contour_Xs/1e3, foo_contour_Ys/1e3, '.',  markersize=2*scaling_factor, color=default_colors[3]) # internal profile

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax.set_xlabel("X (km)")
    ax.set_ylabel("Y (km)")

    # Second plot: properties along the profile, P, T
    T_lim = (400.0, 1200.0)
    P_lim = (5.0, 25.0)
    T_tick_interval = 100.0
    P_tick_interval = 5.0
    depth_lim = (300.0, 700)
    depth_tick_inverval = 100.0

    ax4 = fig.add_subplot(gs[0, 1])
    ax4.plot(foo_contour_Ts-273.15, foo_contour_depths/1e3, linewidth=4*scaling_factor, color=default_colors[1], label="T")


    ax4.set_xlim(T_lim)
    ax4.set_ylim(depth_lim)

    ax4.xaxis.set_major_locator(MultipleLocator(T_tick_interval))
    ax4.xaxis.set_minor_locator(MultipleLocator(T_tick_interval/(n_minor_ticks+1)))
    ax4.yaxis.set_major_locator(MultipleLocator(depth_tick_inverval))
    ax4.yaxis.set_minor_locator(MultipleLocator(depth_tick_inverval/(n_minor_ticks+1)))

    # Pressure shares the depth axis but has its own x axis on top
    ax5 = ax4.twiny()
    ax5.plot(foo_contour_Ps/1e9, foo_contour_depths/1e3, color=default_colors[1], label="P")

    ax5.set_xlim(P_lim)
    ax5.set_ylim(depth_lim)

    ax5.xaxis.set_major_locator(MultipleLocator(P_tick_interval))
    ax5.xaxis.set_minor_locator(MultipleLocator(P_tick_interval/(n_minor_ticks+1)))

    ax4.invert_yaxis()
    ax4.grid()

    # One legend for both twin axes, so the T line is listed along with P
    handles4, labels4 = ax4.get_legend_handles_labels()
    handles5, labels5 = ax5.get_legend_handles_labels()
    ax5.legend(handles4 + handles5, labels4 + labels5)

    ax4.set_xlabel(r"Temperature ($^{\circ}C$)")
    ax5.set_xlabel("Pressure (GPa)")
    ax4.set_ylabel("Depth (km)")

    q_depths = [400e3, 500e3, 600e3, 700e3] # print the time from subduction 
    for q_depth in q_depths:
        q_t_sub = np.interp(q_depth, depths, lengths) / subduct_velocity
        print("Depth = %.1f km, t_sub = %.4e Ma" % (q_depth/1e3, q_t_sub/year/1e6))

    # Additional plot: slab morphology
    # Create the first subplot for lengths
    # Create a twin y-axis for dip angles
    # Add legends and customize the plot
    ax1 = fig.add_subplot(gs[1, 1])
    color1 = 'tab:blue'
    ax1.plot(depths/1e3, lengths/1e3, label="Lengths", color=color1, linewidth=2)  # Plot lengths

    ax1.set_xlim([0.0, 1000.0])
    ax1.set_ylim([0.0, 1400.0])

    x_tick_interval1 = 200.0; y_tick_interval1 = 200.0
    ax1.xaxis.set_major_locator(MultipleLocator(x_tick_interval1))
    ax1.xaxis.set_minor_locator(MultipleLocator(x_tick_interval1/(n_minor_ticks+1)))
    ax1.yaxis.set_major_locator(MultipleLocator(y_tick_interval1))
    ax1.yaxis.set_minor_locator(MultipleLocator(y_tick_interval1/(n_minor_ticks+1)))

    # Both quantities are plotted divided by 1e3, i.e. in km
    ax1.set_xlabel("Depth (km)")
    ax1.set_ylabel("Lengths (km)", color=color1)
    ax1.tick_params(axis='y', labelcolor=color1)

    ax2 = ax1.twinx()
    color2 = 'tab:orange'
    ax2.plot(depths/1e3, np.degrees(dip_angles), label="Dip Angles", color=color2, linewidth=2.0*scaling_factor, linestyle="--")  # Plot dip angles

    ax2.set_ylim([0.0, 90.0])

    y_tick_interval2 = 10.0
    ax2.yaxis.set_major_locator(MultipleLocator(y_tick_interval2))
    ax2.yaxis.set_minor_locator(MultipleLocator(y_tick_interval2/(n_minor_ticks+1)))

    ax2.set_ylabel("Dip Angles (°)", color=color2)
    ax2.tick_params(axis='y', labelcolor=color2)

    ax1.legend(loc="upper left")
    ax2.legend(loc="upper right")
    ax1.set_title("Lengths and Dip Angles vs Depths")
    ax1.grid(True, linestyle="--", alpha=0.5)

    # plot the subduction time

    ax3 = fig.add_subplot(gs[1, 0])

    ax3.plot((Xs + trench_x)/1e3, (y_max - depths)/1e3, ".", markersize=2*scaling_factor, color=default_colors[0]) # slab surface
    ax3.plot(xv[mask_l]/1e3, yv[mask_l]/1e3, ".", markersize=2*scaling_factor, color=default_colors[2]) # 100 km from slab surface

    levels3 = np.linspace(v_lim3[0], v_lim3[1], v_level)
    ticks3=np.arange(v_lim3[0], v_lim3[1], v_tick_interval3)
    color_map3 = ax3.contourf(xv/1e3, yv/1e3, t_sub_grid/year/1e6, vmin=v_lim3[0], vmax=v_lim3[1], levels=levels3, cmap="viridis") # sub_t
    cbar3 = fig.colorbar(color_map3, ax=ax3, label="time from trench (Ma)")
    cbar3.set_ticks(ticks3)

    ax3.set_xlim(x_lim)
    ax3.set_ylim(y_lim)

    ax3.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax3.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax3.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax3.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax3.set_xlabel("X (km)")
    ax3.set_ylabel("Y (km)")

    # Adjust spine thickness for all axes BEFORE saving, so it ends up in the PDF
    for spine_ax in (ax, ax1, ax2, ax3, ax4, ax5):
        for spine in spine_ax.spines.values():
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # Save figure to a PDF file (before plt.show(), which may close the figure on some backends)
    pdf_path = os.path.join(case_dir, "T.pdf")
    fig.savefig(pdf_path)
    print("Saved figure %s" % pdf_path)

    # Show figure
    plt.show()

    end = time.time()
    print("Plotting color map takes %.2f s" % (end - start))
    start = end

    # Reset rcParams to defaults

    rcdefaults()

## Derive MO contents for the profile defined previously

This script first queries the position of the profile extracted at a fixed distance from the slab surface. It then computes both the equilibrium and the metastable phase contents along that profile.

Different methods:

1. Hosoya_2005: following the analytical method in Hosoya_2005
2. Blocking temperature of 725 °C, from Quinteros_Sobolev_2012
3. Kinetics defined by metastable.py

In [None]:
if is_run_aspect:
    # Compute equilibrium and metastable (MO) contents along the diagnostic
    # profile (foo_contour_*). Choose the method by enabling one parameter set below.

    ## Method 1: MO kinetics with a model
    # MO_method = "hosoya_2005"; blocking_T = None; nucleation_type=1
    # n_t = None; n_span = None
    
    ## Method 2: blocking temperature
    # MO_method = "blockT";
    # blocking_T = 725.0 + 273.15; nucleation_type=1;
    # n_t=None; n_span = None

    ## Method 3: Using the kinetic relations
    MO_method = "kinetics"; blocking_T = None; nucleation_type=1; 
    n_t = 1; n_span = 10

    # NOTE(review): 150 cannot be a wt% value; presumably a water (OH) content in
    # ppm -- confirm the units expected by Meta.PTKinetics / Meta.MO_KINETICS.
    Coh = 150.0 # water content for methods with mo kinetics
    d_ol = 10e-3 # m background grain size for methods with mo kinetics


    # Output directory: <case_dir>/<MO_method>/dol_<d_ol>_coh_<Coh>
    output_dir = os.path.join(case_dir, MO_method, "dol_%.2e_coh_%.1f" % (d_ol, Coh))
    os.makedirs(output_dir, exist_ok=True)

    # equilibrium contents: 1 where P exceeds the equilibrium pressure, else 0
    foo_contents_wl_eq = np.zeros(foo_contour_Ts.shape)
    foo_contents_wl_eq[foo_mask_eq] = 1.0

    # metastable contents
    foo_contents_wl_mo = None
    if MO_method == "blockT":
        # Below the blocking temperature the transformation is suppressed entirely
        foo_contents_wl_mo = foo_contents_wl_eq.copy()
        foo_mask_mo = (foo_contour_Ts < blocking_T)
        foo_contents_wl_mo[foo_mask_mo] = 0.0
    elif MO_method == "hosoya_2005":
        # initiate the kinetics class 
        _constants, _ = Meta.get_kinetic_constants(nucleation_type)
        pTKinetics = Meta.PTKinetics(_constants)

        # Growth rates are only evaluated where the equilibrium boundary is crossed
        growth_rates = np.zeros(foo_contour_Ps.shape)
        mask = (foo_contour_Ps > foo_contour_Ps_eq)
        growth_rates[mask] = pTKinetics.growth_rate_interface_P2(foo_contour_Ps[mask], foo_contour_Ts[mask],\
                                                                foo_contour_Ps_eq[mask], foo_contour_Ts_eq[mask], Coh)
        foo_contents_wl_mo = MO_Vfraction_classic(growth_rates, foo_contour_ts, d_ol)
    elif MO_method == "kinetics":
        # initiate the kinetics class 
        _, _constants1 = Meta.get_kinetic_constants(nucleation_type)
        Mo_Kinetics = Meta.MO_KINETICS(_constants1)
        Mo_Kinetics.set_initial_grain_size(d_ol)

        Mo_Kinetics.set_PT_eq(PT410['P'], PT410['T'], PT410['cl'])
        Mo_Kinetics.link_and_set_kinetics_model(Meta.PTKinetics)

        # Step along the profile: each step integrates the ODEs over [t_i, t_{i+1}]
        # at the (P, T) of point i, continuing from the previous step's final state.
        # The first point keeps a content of 0.
        foo_contents_wl_mo = np.zeros(foo_contour_Ps.size)
        for i in range(foo_contour_Ps.size-1):
            # parse variables:
            # P, T
            # t0, t1 - start and end of the time step
            P = foo_contour_Ps[i]
            T = foo_contour_Ts[i]
            t0 = foo_contour_ts[i]
            t1 = foo_contour_ts[i+1]
            Mo_Kinetics.set_kinetics_fixed(P, T, Coh)

            # solve the ODEs (None starts from the solver's initial state)
            if i == 0:
                _initial = None
            else:
                _initial = results[-1, :]
            results = Mo_Kinetics.solve(P, T, t0, t1, n_t, n_span, initial=_initial)
            # results = Mo_Kinetics.solve(P, T, 0, t1, n_t, n_span)
            # column 5 of the last solver row is taken as the transformed content
            foo_contents_wl_mo[i+1] = results[-1, 5]
            
    else:
        raise NotImplementedError("MO_method %s is not implemented" % MO_method)

In [None]:
if is_run_aspect:

    # Initialize plots
    # Two panels along the diagnostic profile:
    #   left:  equilibrium (dashed) and metastable (solid) Wd contents vs depth
    #   right: time since the trench vs depth
    # The figure is saved under output_dir (set by the MO-contents cell).

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    from matplotlib import gridspec

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Example usage
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (0.0, 1.0) # volume fraction
    x_tick_interval = 0.25   # tick interval along x
    # x_lim1 = (0.0, 1.0) # time in Ma
    # x_tick_interval1 = 0.25  
    y_lim = (300.0, 700.0)  # depth range (km)
    y_tick_interval = 100  # tick interval along y
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # Create a new figure and axis
    fig = plt.figure(figsize=(16, 5))
    gs = gridspec.GridSpec(1, 2)

    ax = fig.add_subplot(gs[0, 0])

    # Plot contents of wl (x) against depth (y)
    ax.plot(foo_contents_wl_eq, foo_contour_depths/1e3, linestyle="--", color=default_colors[1], label="Contents Wd EQ")
    ax.plot(foo_contents_wl_mo, foo_contour_depths/1e3, linestyle="-", color=default_colors[1], label="Contents Wd MO (%s)" % MO_method)

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax.set_xlabel("Transformed Volume")
    ax.set_ylabel("Depth (km)")

    # depth increases downward
    ax.invert_yaxis()

    ax.legend()

    ax.grid()
    
    # Adjust spine thickness for this plot
    for spine in ax.spines.values():
        spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # Plot subduction time
    ax1 = fig.add_subplot(gs[0, 1])
    
    # foo_contour_ts is in seconds; convert to Ma
    ax1.plot(foo_contour_ts/year/1e6, foo_contour_depths/1e3, linestyle="-", color=default_colors[2], label="Subducting Time (%s)" % MO_method)

    # ax1.set_xlim(x_lim1)
    ax1.set_ylim(y_lim)

    # ax1.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    # ax1.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax1.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax1.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax1.set_xlabel("Time (Ma)")
    ax1.set_ylabel("Depth (km)")

    ax1.invert_yaxis()

    ax1.legend()

    ax1.grid()
    
    # Adjust spine thickness for this plot
    for spine in ax1.spines.values():
        spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # Save figure (file name records the method and the profile offset in km)
    pdf_path = os.path.join(output_dir, "profile_Wd_contents_%s_%.1fkm.pdf" % (MO_method, contour_distance[0]/1e3))
    fig.savefig(pdf_path)
    print("Saved figure %s" % pdf_path)

    # Reset rcParams to defaults
    rcdefaults()

## Next, we conduct a numerical convergence test of the kinetics solver (number of sub-steps)

In [None]:
if is_run_aspect:

    # Convergence test of the kinetics solver: recompute the metastable Wd content
    # along the diagnostic profile for several n_span values and compare each run
    # at one query point against the finest run (the last n_span).
    # Test-local names are used so that the method settings chosen earlier
    # (n_t, n_span, Mo_Kinetics, results, foo_contents_wl_mo) are NOT overwritten;
    # later cells read n_t / n_span and would otherwise silently use the last test value.

    # test parameters 
    n_t_test = 1
    n_span_array = [2, 3, 5, 10, 20]

    # one content array per n_span value
    foo_contents_wl_mo_array = []    

    # run profile analysis on different parameters 
    for n_span_test in n_span_array:
        # initiate the kinetics class 
        _, _constants1 = Meta.get_kinetic_constants(nucleation_type)
        mo_kinetics_test = Meta.MO_KINETICS(_constants1)
        mo_kinetics_test.set_initial_grain_size(d_ol)

        mo_kinetics_test.set_PT_eq(PT410['P'], PT410['T'], PT410['cl'])
        mo_kinetics_test.link_and_set_kinetics_model(Meta.PTKinetics)

        # set metastable contents along the profile; the first point stays 0
        contents_test = np.zeros(foo_contour_Ps.size)
        # solver state carried between steps; None starts from the initial state
        state_test = None
        for i_p in range(foo_contour_Ps.size - 1):
            # (P, T) of point i_p held fixed over the time step [t0, t1]
            P = foo_contour_Ps[i_p]
            T = foo_contour_Ts[i_p]
            t0 = foo_contour_ts[i_p]
            t1 = foo_contour_ts[i_p + 1]
            mo_kinetics_test.set_kinetics_fixed(P, T, Coh)

            # solve the ODEs
            step_results = mo_kinetics_test.solve(P, T, t0, t1, n_t_test, n_span_test, initial=state_test)
            state_test = step_results[-1, :]
            contents_test[i_p + 1] = step_results[-1, 5]

        foo_contents_wl_mo_array.append(contents_test)

    # assess the differences
    # use the last result (largest n_span) as the reference
    i_query = 238  # profile index at which the runs are compared
    print("Depth = ", foo_contour_depths[i_query])  # output
    print("ts = %.4e Ma" % (foo_contour_ts[i_query]/year/1e6))
    
    foo_contents_wl_mo_std = foo_contents_wl_mo_array[-1] 
    std_result = foo_contents_wl_mo_std[i_query]

    foo_contents_wl_mo_q_array = [] # query results, one per n_span value
    foo_relative_error_q_array = [] # relative errors w.r.t. the reference
    for i, contents_test in enumerate(foo_contents_wl_mo_array):

        query_result = contents_test[i_query]
        relative_error = np.abs((query_result - std_result)/std_result)

        foo_contents_wl_mo_q_array.append(query_result)
        foo_relative_error_q_array.append(relative_error)
        
        print("    i = ", i) # output
        print("    foo_contents_wl_mo[%d] = " % i_query)
        print(contents_test[i_query])

In [None]:
if is_run_aspect:

    # Initialize plots
    # Plot the convergence test: transformed volume at the query point (left axis)
    # and its relative error w.r.t. the finest run (right axis) versus n_span.

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    from matplotlib import gridspec

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Example usage
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (0.0, 20) # range of n_span (x axis, "stepping")
    x_tick_interval = 5   # tick interval along x
    y_lim1 = (0.0, 0.2) # Relative error
    y_tick_interval1 = 0.05
    y_lim = (0.4, 0.5) # transformed volume at the query point
    y_tick_interval = 0.025  # tick interval along y
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # Plot the solution and relative error in twin axis
    fig, ax = plt.subplots(figsize = (8*scaling_factor, 5*scaling_factor))

    ax.plot(n_span_array, foo_contents_wl_mo_q_array, color=default_colors[0])

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))
    
    ax.set_xlabel("stepping")
    ax.set_ylabel("Transformed Volume")

    # Adjust spine thickness for this plot
    for spine in ax.spines.values():
        spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # Relative error on a second y axis sharing the n_span x axis
    ax1 = ax.twinx()
    
    ax1.plot(n_span_array, foo_relative_error_q_array, color=default_colors[1])
    
    ax1.set_xlim(x_lim)
    ax1.set_ylim(y_lim1)

    ax1.yaxis.set_major_locator(MultipleLocator(y_tick_interval1))
    ax1.yaxis.set_minor_locator(MultipleLocator(y_tick_interval1/(n_minor_ticks+1)))
    
    ax1.set_ylabel("Relative Error")

    ax.grid() # turn on grid

    # Save figure (file name records the profile offset in km)
    pdf_path = os.path.join(output_dir, "Wd_contants_convergence_%.1fkm.pdf" % (contour_distance[0]/1e3))
    fig.savefig(pdf_path)
    print("Saved figure %s" % pdf_path)

    # Reset rcParams to defaults
    rcdefaults()

## Derive contents along multiple offset curves and interpolate them back onto a 2-D grid

Plot the results on a 2-D mesh

In [None]:
if is_run_aspect:

    # Compute equilibrium and metastable Wd contents along many profiles offset
    # from the slab surface (0-99 km), then interpolate the metastable contents
    # onto a regular 2-D grid limited to the slab region.

    from scipy.spatial import cKDTree
    from scipy.interpolate import griddata

    def interpolate_grid_from_points(solution_Xs, solution_Ys, solution_Zs, spacing, distance_threshold):
        """
        Resample scattered values onto a regular grid by nearest-neighbour lookup.

        Args:
            solution_Xs, solution_Ys (np.ndarray): coordinates of the scattered points (1-D, same length).
            solution_Zs (np.ndarray): values at the scattered points.
            spacing (float): grid spacing, used for both x and y.
            distance_threshold (float): grid nodes with no data point within this
                distance are set to NaN.

        Returns:
            tuple: (grid_x, grid_y, grid_z) 2-D arrays built with np.meshgrid; the
            nodes run from the minimum coordinate in steps of `spacing`, excluding
            the maximum (np.arange).
        """
        # Create a 2D regular grid
        x_min, x_max = solution_Xs.min(), solution_Xs.max()
        y_min, y_max = solution_Ys.min(), solution_Ys.max()

        # Define grid resolution
        grid_x, grid_y = np.meshgrid(
            np.arange(x_min, x_max, spacing),
            np.arange(y_min, y_max, spacing)
        )

        # Flatten the grid points for easier querying
        grid_points = np.column_stack((grid_x.ravel(), grid_y.ravel()))

        # Build a KDTree with the original data points
        tree = cKDTree(np.column_stack((solution_Xs, solution_Ys)))

        # Query nearest neighbors for each grid point; nodes without a neighbour
        # inside the bound get distance inf (and an out-of-range index)
        distances, indices = tree.query(grid_points, distance_upper_bound=distance_threshold)

        # Initialize grid values with NaN
        grid_z = np.full(grid_points.shape[0], np.nan)

        # Assign values only for points within the threshold
        valid_mask = distances <= distance_threshold
        grid_z[valid_mask] = solution_Zs[indices[valid_mask]]

        # Reshape the result to match the grid
        grid_z = grid_z.reshape(grid_x.shape)

        return grid_x, grid_y, grid_z


    # Generate contours at specified distances from the slab surface
    # Offsets (m): 0, 1, ..., 99 km
    contour_distance_array = np.arange(0, 100e3, 1e3)

    # Per-profile pieces, concatenated once after the loop (calling np.concatenate
    # inside the loop copies the growing arrays on every profile)
    Xs_parts = []
    Ys_parts = []
    wl_eq_parts = []
    wl_mo_parts = []
    contour_traces = [[] for _ in range(contour_distance_array.size)]

    # Generate Grid results for equilibrium phase transitions
    contents_wl_eq_grid = np.zeros(xv.shape)
    P_eq_grid = (T_grid - PT410["T"]) * PT410["cl"] + PT410["P"]
    mask_eq_grid = ( P_grid > P_eq_grid)
    contents_wl_eq_grid[mask_eq_grid] = 1.0

    # Compute profile-wise results for metastable phase transitions
    # First, initiate classes
    pTKinetics = None; Mo_Kinetics = None
    if MO_method == "blockT":
        pass
    elif MO_method == "hosoya_2005":
        _constants, _ = Meta.get_kinetic_constants(nucleation_type)
        pTKinetics = Meta.PTKinetics(_constants)
    elif MO_method == "kinetics":
        _, _constants1 = Meta.get_kinetic_constants(nucleation_type)
        Mo_Kinetics = Meta.MO_KINETICS(_constants1)
        Mo_Kinetics.set_initial_grain_size(d_ol)

        Mo_Kinetics.set_PT_eq(PT410['P'], PT410['T'], PT410['cl'])
        Mo_Kinetics.link_and_set_kinetics_model(Meta.PTKinetics)
    else:
        raise NotImplementedError("MO_method %s is not implemented" % MO_method)

    for i_prof, contour_distance_1 in enumerate(contour_distance_array):
        contour_Xs, contour_Ys = offset_curve(Xs + trench_x, y_max - depths, contour_distance_1)

        grid_points_2d = np.vstack([contour_Xs.ravel(), contour_Ys.ravel()]).T

        contour_Ts = interpolator(grid_points_2d)
        contour_Ps = interpolator_P(grid_points_2d)

        # Equilibrium boundary as P_eq(T) and T_eq(P) along the profile
        contour_Ps_eq = (contour_Ts - PT410["T"]) * PT410["cl"] + PT410["P"]
        contour_Ts_eq = (contour_Ps - PT410["P"]) / PT410["cl"] + PT410["T"]

        mask_eq = (contour_Ps > contour_Ps_eq)

        # NOTE(review): here each profile point gets the time of the slab-surface
        # point it was offset from, whereas the single diagnostic profile
        # interpolates lengths at the profile point's own depth. This also assumes
        # offset_curve returns exactly one point per slab-surface point -- confirm.
        contour_ts = lengths / subduct_velocity

        # equilibrium contents
        contents_wl_eq = np.zeros(contour_Ts.shape)
        contents_wl_eq[mask_eq] = 1.0

        # metastable contents
        contents_wl_mo = None
        if MO_method == "blockT":
            contents_wl_mo = contents_wl_eq.copy()
            mask_mo = (contour_Ts < blocking_T)
            contents_wl_mo[mask_mo] = 0.0
        elif MO_method == "hosoya_2005":
            # compute growth rate and then solve for metastable contents
            growth_rates = np.zeros(contour_Ps.shape)
            mask = (contour_Ps > contour_Ps_eq)
        
            growth_rates[mask] = pTKinetics.growth_rate_interface_P2(contour_Ps[mask], contour_Ts[mask],\
                                                                contour_Ps_eq[mask], contour_Ts_eq[mask], Coh)
            contents_wl_mo = MO_Vfraction_classic(growth_rates, contour_ts, d_ol)
        elif MO_method == "kinetics":
            # compute metastable contents along the profile; each step integrates
            # over [t_i, t_{i+1}] at the (P, T) of point i, continuing from the
            # previous step's final state (fresh start for each profile)
            contents_wl_mo = np.zeros(contour_Ps.size)
            for i_p in range(contour_Ps.size-1):
                # parse variables:
                # P, T
                # t0, t1 - start and end of the time step
                P = contour_Ps[i_p]
                T = contour_Ts[i_p]
                t0 = contour_ts[i_p]
                t1 = contour_ts[i_p+1]
                Mo_Kinetics.set_kinetics_fixed(P, T, Coh)

                # solve the ODEs
                if i_p == 0:
                    _initial = None
                else:
                    _initial = results[-1, :]
                results = Mo_Kinetics.solve(P, T, t0, t1, n_t, n_span, initial=_initial)
                # results = Mo_Kinetics.solve(P, T, 0, t1, n_t, n_span)
                contents_wl_mo[i_p+1] = results[-1, 5]
        else:
            raise NotImplementedError("MO_method %s is not implemented" % MO_method)
            
        # collect this profile's solution
        Xs_parts.append(contour_Xs)
        Ys_parts.append(contour_Ys)
        wl_eq_parts.append(contents_wl_eq)
        wl_mo_parts.append(contents_wl_mo)
        contour_traces[i_prof] += [contour_Xs, contour_Ys, contents_wl_eq, contents_wl_mo]

    # Join all profiles (a leading empty float array keeps the result a 1-D float
    # array, as the former incremental concatenation did)
    solution_Xs = np.concatenate([np.array([])] + Xs_parts)
    solution_Ys = np.concatenate([np.array([])] + Ys_parts)
    solution_wl_eq = np.concatenate([np.array([])] + wl_eq_parts)
    solution_wl_mo = np.concatenate([np.array([])] + wl_mo_parts)

    # interpolate solution to a grid
    xi = np.linspace(solution_Xs.min(), solution_Xs.max(), 200)
    yi = np.linspace(solution_Ys.min(), solution_Ys.max(), 200)


    solution_X_grid, solution_Y_grid = np.meshgrid(xi, yi)

    solution_distance_grid = distances_to_curve(Xs + trench_x, y_max - depths, solution_X_grid.ravel(), solution_Y_grid.ravel())
    solution_distance_grid = solution_distance_grid.reshape(solution_X_grid.shape)
    mask_slab_grid = (solution_distance_grid >= -5e3) & (solution_distance_grid <= 100e3)  # Slab internal region: -5 km to 100 km

    # Interpolate the scattered data (NaN outside the convex hull of the points)
    solution_wl_mo_grid = griddata(
        (solution_Xs, solution_Ys),
        solution_wl_mo,
        (solution_X_grid, solution_Y_grid),
        method='linear'   # or 'cubic' for even smoother
    )

    solution_wl_mo_grid[~mask_slab_grid] = np.nan # set nan value for value out of the slab

In [None]:
if is_run_aspect:

    # Initialize plots

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    import matplotlib.ticker as ticker
    from matplotlib import gridspec

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Example usage
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (0.0, 1000.0)
    x_tick_interval = 250   # tick interval along x
    y_lim = (0.0, 1000.0)
    y_tick_interval = 250  # tick interval along y
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })



    # Plot equilibrium & metastable phase transition
    fig = plt.figure(figsize=(8, 12), tight_layout=True)
    gs = gridspec.GridSpec(2, 1)

    # Equilibrium phases outside of the slab
    # Metastable phases inside the slab
    #   1. use the scatter plot from calculated points
    #   2. use the contourf plot from the grid data
    ax = fig.add_subplot(gs[0, 0])

    color_map3 = ax.contourf(xv/1e3, (y_extent - yv)/1e3, contents_wl_eq_grid, levels=100, cmap="viridis", vmin=0.0, vmax=1.0)
    cbar3 = fig.colorbar(color_map3, ax=ax, label="Wd contents")
    cbar3.set_ticks([0.0, 0.5, 1.0]) # colorbar ticks options
    minor_locator = ticker.MultipleLocator(0.1)  # This gives 4 minor ticks between major ticks spaced by 0.5
    cbar3.ax.yaxis.set_minor_locator(minor_locator)

    ax.plot((Xs + trench_x)/1e3, depths/1e3, "--k") # slab surface
    ax.plot(contour_traces[-1][0]/1e3, (y_extent - contour_traces[-1][1])/1e3, "--k") # slab surface
    # contours = ax.contour(xv, yv, distance_grid, levels=[0, 100e3], colors=['black', 'black'], linewidths=2) # curve of the slab

    contours = ax.contour(
        xv/1e3, (y_extent - yv)/1e3, T_grid-273.15, levels=np.arange(100.0, 1473.15 + 100.0, 200.0), colors="black", linewidths=0.5
    ) # temperature contours
    ax.clabel(contours, inline=True, fontsize=8, fmt="%.1f")  # Add labels to the contours

    # ax.scatter(solution_Xs/1e3, (y_extent - solution_Ys)/1e3, c=solution_wl_mo, cmap='viridis', s=20, vmin=0.0, vmax=1.0) # slab surface
    cset = ax.contourf(solution_X_grid/1e3, (y_extent - solution_Y_grid)/1e3, solution_wl_mo_grid, levels=50, cmap='viridis', vmin=0.0, vmax=1.0)

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.invert_yaxis()

    ax.set_xlabel("X (km)")
    ax.set_ylabel("Depth (km)")

    ax.set_aspect("equal", adjustable="box")

    # Equilibrium & metastable phases, zoom in
    ax1 = fig.add_subplot(gs[1, 0])

    color_map1 = ax1.contourf(xv/1e3, (y_extent - yv)/1e3, contents_wl_eq_grid, levels=200, cmap="viridis", vmin=0.0, vmax=1.0)
    cbar1 = fig.colorbar(color_map1, ax=ax1, label="Wd contents")
    cbar1.set_ticks([0.0, 0.5, 1.0]) # colorbar ticks options
    minor_locator = ticker.MultipleLocator(0.1)  # This gives 4 minor ticks between major ticks spaced by 0.5
    cbar1.ax.yaxis.set_minor_locator(minor_locator)

    ax1.plot((Xs + trench_x)/1e3, depths/1e3, "--k") # slab surface
    ax1.plot(contour_traces[-1][0]/1e3, (y_extent - contour_traces[-1][1])/1e3, "--k") # slab surface

    # ax1.scatter(solution_Xs/1e3, (y_extent - solution_Ys)/1e3, c=solution_wl_mo, cmap='viridis', s=20, vmin=0.0, vmax=1.0) # slab surface
    ax1.contourf(solution_X_grid/1e3, (y_extent - solution_Y_grid)/1e3, solution_wl_mo_grid, levels=50, cmap='viridis', vmin=0.0, vmax=1.0)
    cset1 = ax1.contour(solution_X_grid/1e3, (y_extent - solution_Y_grid)/1e3, solution_wl_mo_grid, levels=(0.5,))

    # x_center = np.interp(500e3, y_max - depths, Xs + trench_x) # find the pin point center, issue: interpolation doesn't work
    x_center = 700

    ax1.set_xlim([x_center-200, x_center+200])
    ax1.set_ylim([300, 700])

    contours = ax1.contour(
        xv/1e3, (y_extent - yv)/1e3, T_grid-273.15, levels=np.arange(100.0, 1473.15 + 100.0, 200.0), colors="black", linewidths=0.5
    ) # temperature contours
    ax1.clabel(contours, inline=True, fontsize=8, fmt="%.1f")  # Add labels to the contours

    x_tick_interval1 = 100.0
    y_tick_interval1 = 100.0
    ax1.xaxis.set_major_locator(MultipleLocator(x_tick_interval1))
    ax1.xaxis.set_minor_locator(MultipleLocator(x_tick_interval1/(n_minor_ticks+1)))
    ax1.yaxis.set_major_locator(MultipleLocator(y_tick_interval1))
    ax1.yaxis.set_minor_locator(MultipleLocator(y_tick_interval1/(n_minor_ticks+1)))

    ax1.invert_yaxis()

    ax1.set_xlabel("X (km)")
    ax1.set_ylabel("Depth (km)")

    ax1.set_aspect("equal", adjustable="box")



    # save figure
    pdf_path = os.path.join(output_dir, "wd_contents_%s.pdf" % MO_method)
    fig.savefig(pdf_path)
    print("Saved figure %s" % pdf_path)

    plt.show()

    # Adjust spine thickness for this plot
    for spine in ax.spines.values():
        spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # Reset rcParams to defaults
    rcdefaults()

In [None]:
# interpolate to regular grid
# Issue: have to adjust spacing and distance_threshold, otherwise has nan values

# spacing = 5e3
# distance_threshold = 5e3
# grid_x, grid_y, grid_wl_mo = interpolate_grid_from_points(solution_Xs, solution_Ys, solution_wl_mo, spacing, distance_threshold)


# # Plot distance to slab surface
# fig = plt.figure(figsize=(8, 12), tight_layout=True)
# gs = gridspec.GridSpec(2, 1)

# # Equilibrium phases
# ax = fig.add_subplot(gs[0, 0])

# # color_map3 = ax.contourf(xv, yv, distance_grid/1e3, levels=100, cmap="cividis")
# # cbar3 = fig.colorbar(color_map3, ax=ax, label="Distance to surface (km)")

# color_map_is_nan = ax.scatter(grid_x, grid_y, c=np.isnan(grid_wl_mo), cmap='viridis', s=20) # slab surface
# cbar_is_nan = fig.colorbar(color_map_is_nan, ax=ax, label="nan values")

# ax.set_xlim([0.0, x_max])
# ax.set_ylim([0.0, y_max])

# ax.set_xlabel("X")
# ax.set_ylabel("Y")

# ax.set_aspect("equal", adjustable="box")

In [None]:
if is_run_aspect:

    # Assemble the Wd-content profile: one row per depth sample in `depths`,
    # one column per contour trace (distance from the slab surface).
    # NOTE(review): assumes element 3 of each contour trace holds the Wd content
    # sampled at `depths` (only its first depths.size entries are used) — confirm.
    wd_n_profile = np.zeros([depths.size, len(contour_traces)])
    for j, contour_trace in enumerate(contour_traces):
        wd_n_profile[:, j] = contour_trace[3][:depths.size]

    # show cross-sections at the query depths
    query_depth_array = [400e3, 450e3, 500e3]

    fig, ax = plt.subplots(figsize=(5, 5))

    for query_depth in query_depth_array:
        # nearest available depth sample; label with the depth actually plotted
        i = np.argmin(np.abs(depths - query_depth))
        depth_q = depths[i]
        ax.plot(contour_distance_array/1e3, wd_n_profile[i, :], label="depth %.1f km" % (depth_q/1e3))

    ax.set_xlabel("Distance to Slab Surface (km)")
    ax.set_ylabel("Wd Contents")

    ax.legend()

    ax.grid()

    # save figure
    pdf_path = os.path.join(output_dir, "MO_wedge_profile1.pdf")
    fig.savefig(pdf_path)
    print("Saved figure %s" % pdf_path)

In [None]:
if is_run_aspect:
    # For every query depth, find the olivine (untransformed) wedge: the range of
    # distances from the slab surface where the Wd content drops below wd_threshold.
    wd_threshold = 0.5
    query_depth_array = np.arange(300e3, 700e3, 10e3)

    n_query = query_depth_array.size
    get_depth_array = np.zeros(query_depth_array.shape)
    # columns: (first, last) distance below the threshold; stays nan when no wedge exists
    get_distance_array = np.full([n_query, 2], np.nan)
    for i_q in range(n_query):
        i = np.argmin(np.abs(depths - query_depth_array[i_q]))  # nearest depth sample
        get_depth_array[i_q] = depths[i]
        below = np.flatnonzero(wd_n_profile[i, :] < wd_threshold)
        if below.size == 0:
            continue
        get_distance_array[i_q, 0] = contour_distance_array[below[0]]
        get_distance_array[i_q, 1] = contour_distance_array[below[-1]]

    fig = plt.figure(figsize=(10, 10), tight_layout=True)
    gs = gridspec.GridSpec(1, 2)

    depth_km = get_depth_array/1e3
    distance_max_km = contour_distance_array[-1]/1e3

    # left panel: both edges of the wedge
    ax = fig.add_subplot(gs[0, 0])
    ax.plot(get_distance_array[:, 0]/1e3, depth_km, color="b", label="Olivine Wedge")
    ax.plot(get_distance_array[:, 1]/1e3, depth_km, color="b")
    ax.set_xlabel("Distance to Slab Surface (km)")

    # right panel: wedge width
    ax1 = fig.add_subplot(gs[0, 1])
    ax1.plot((get_distance_array[:, 1] - get_distance_array[:, 0])/1e3, depth_km, color="c", label="Wedge Width")
    ax1.set_xlabel("Wd Width (km)")

    # layout shared by both panels (depth increases downward)
    for panel in (ax, ax1):
        panel.set_xlim([0, distance_max_km])
        panel.set_ylim([300.0, 600.0])
        panel.set_ylabel("Depth (km)")
        panel.grid()
        panel.invert_yaxis()
        panel.legend()
        panel.set_aspect('equal')

    # save figure
    pdf_path = os.path.join(output_dir, "MO_wedge_profile_wd%.2f.pdf" % (wd_threshold))
    fig.savefig(pdf_path)
    print("Saved figure %s" % pdf_path)

# Grain Size Evolution Model

## ASPECT test grain_size_growth

### Create case

In [None]:
create_and_run_ggrowth_case = False

if create_and_run_ggrowth_case:

    from shutil import rmtree
    from hamageolib.utils.dealii_param_parser import parse_parameters_to_dict, save_parameters_from_dict
    from hamageolib.utils.world_builder_file_parser import find_feature_by_name, update_or_add_feature

    # Case options
    create_and_run_ggrowth_case_solve = True # whether to run or not
    parent_dir = "/mnt/lochy/ASPECT_DATA/MOW/mow_tests1"
    end_time = 1e6 # yr
    maximum_timestep = 1e4 # yr
    visualization_timestep = 1e4 # yr

    adiabatic_surfT = 1000 # default - 1600.0 K

    # Horizontal velocity (m/yr) used later to follow a material point.
    # NOTE(review): not written into the prm — presumably matches the template's velocity; confirm.
    Vx_m_yr = 0.1

    use_kinetics = "wadleyite" # use kinetic relations for wd
    initial_grain_size = 1e-8 # default - 8e-5

    # Growth kinetics
    if use_kinetics == "wadleyite":
        growth_E = 6.62e5   # Eg (J/mol)
        growth_exponent = 3   # pg
        growth_rate_constant = 3.02e-4  # k0 (m^pg / s)
        # growth_geometric_constant = 3.0     # c
        # growth_work_fraction_for_boundary_area_change = 0.1     # lambda
        # growth_average_specific_grain_boundary_area_energy = 1.0    # gamma
    elif use_kinetics == "default":
        growth_E = 3e5   # Eg (J/mol)
        growth_exponent = 10   # pg
        growth_rate_constant = 4e-45  # k0 (m^pg / s)
    else:
        raise NotImplementedError

    # aspect directory and executable
    aspect_dir = "/home/lochy/Softwares/aspect" # change this to your installed location of aspect
    aspect_executable = os.path.join(aspect_dir, "build_master_TwoD_rebase/aspect")

    # template path of the test
    prm_template_path = os.path.join(aspect_dir, "tests/grain_size_growth.prm")

    # Create directories. makedirs also creates a missing parent_dir, where
    # os.mkdir would raise FileNotFoundError.
    case_dir = os.path.join(parent_dir, "grain_size_growth_%s_ig%.2e_adT%.2f" % (use_kinetics, initial_grain_size, adiabatic_surfT))
    os.makedirs(case_dir, exist_ok=True)
    img_dir = os.path.join(case_dir, "img")
    os.makedirs(img_dir, exist_ok=True)

    # Read the template into a nested dict of prm sections / entries
    with open(prm_template_path, 'r', encoding="utf-8") as file:
        params_dict = parse_parameters_to_dict(file)

    params_dict["Output directory"] = os.path.join(case_dir, "output")

    # Add maximum timestep and End time
    params_dict["End time"] = str(end_time) # yr
    params_dict["Maximum time step"] = str(maximum_timestep) # yr

    # Assign adiabatic surface temperature
    params_dict["Adiabatic surface temperature"] = str(adiabatic_surfT)

    # Assign initial grain size; the lower half (z < 50 km) gets a value perturbed by 1e-5
    params_dict["Initial composition model"]["Function"]["Function expression"] =\
        "if(z<50000,%.6e,%.6e)" % (initial_grain_size, initial_grain_size*(1+1e-5))

    # Assign growth kinetics
    grain_size_model = params_dict["Material model"]["Grain size model"]

    grain_size_model["Grain growth activation energy"] = str(growth_E)
    grain_size_model["Grain growth rate constant"] = str(growth_rate_constant)
    grain_size_model["Grain growth exponent"] = str(growth_exponent)
    grain_size_model["Minimum grain size"] = str(1e-9)

    params_dict["Material model"]["Grain size model"] = grain_size_model

    # Add visualization output, without duplicating it if the template already lists it
    postprocess = params_dict["Postprocess"]
    existing_postprocessors = [p.strip() for p in postprocess["List of postprocessors"].split(",")]
    if "visualization" not in existing_postprocessors:
        postprocess["List of postprocessors"] += ", visualization"
    postprocess["Visualization"] = {
        "List of output variables": "material properties, named additional outputs, nonadiabatic pressure, strain rate, stress, heating",
        "Output format": "vtu",
        "Time between graphical output": str(visualization_timestep)
    }

    # Write to a prm file in the new case directory
    prm_path = os.path.join(case_dir, "case.prm")

    with open(prm_path, 'w', encoding="utf-8") as output_file:
        save_parameters_from_dict(output_file, params_dict)

    if not os.path.isfile(prm_path):
        raise FileNotFoundError("Failed to write parameter file %s" % prm_path)

    print("Created case in %s" % (case_dir))

### Run

In [None]:
if create_and_run_ggrowth_case and create_and_run_ggrowth_case_solve:

    import subprocess

    # Run the ASPECT executable on the generated parameter file.
    # capture_output=True collects stdout and stderr (as text, via text=True) for the checks below.
    completed_process = subprocess.run([aspect_executable, prm_path], capture_output=True, text=True)

    # Capture the standard output and error streams
    stdout = completed_process.stdout
    stderr = completed_process.stderr

    # A finished run prints a summary line such as
    #   -- Total wallclock time elapsed including restarts: 1s
    # Search the whole output instead of a fixed line offset (the old [-6] lookup
    # broke whenever the trailing output changed length, and raised IndexError on short output).
    finished = re.search(r"Total wallclock", stdout) is not None

    # Success requires: zero exit status, the wallclock summary, and an empty error stream.
    if completed_process.returncode != 0 or not finished or stderr.strip() != "":
        stdout_tail = "\n".join(stdout.splitlines()[-20:])
        raise RuntimeError(
            "ASPECT run failed (return code %d)\n--- stderr ---\n%s\n--- last lines of stdout ---\n%s"
            % (completed_process.returncode, stderr, stdout_tail)
        )

In [None]:
if create_and_run_ggrowth_case:
    if create_and_run_ggrowth_case_solve:
        # Show whatever ASPECT wrote to its error stream in the run cell
        print(stderr)

### Post-process

#### Choose one step to plot

In [None]:
create_and_run_ggrowth_case_plot_one_step = True

if create_and_run_ggrowth_case and create_and_run_ggrowth_case_plot_one_step:

    import vtk
    from vtk.util.numpy_support import vtk_to_numpy
    from hamageolib.utils.vtk_utilities import calculate_resolution
    import time
    from scipy.interpolate import LinearNDInterpolator
    from scipy.spatial import Delaunay

    plot_time = 0.0 # yr
    # round() rather than int(): int() truncates, so a float quotient such as
    # 2.9999999999999996 would silently select the previous output file.
    vtu_timestep = int(round(plot_time / visualization_timestep))

    pvtu_file = os.path.join(case_dir, "output", "solution", "solution-%05d.pvtu" % vtu_timestep)
    if not os.path.isfile(pvtu_file):
        raise FileNotFoundError("Missing visualization output %s" % pvtu_file)

    # Read the pvtu file
    start = time.time()

    reader = vtk.vtkXMLPUnstructuredGridReader()
    reader.SetFileName(pvtu_file)
    reader.Update()

    end = time.time()
    print("Initiating reader takes %.2e s" % (end - start))
    start = end

    # Get the output data from the reader
    grid = reader.GetOutput()  # Access the unstructured grid
    data_set = reader.GetOutputAsDataSet()  # Access the dataset representation
    points = grid.GetPoints()  # Extract the points (coordinates)
    cells = grid.GetCells()  # Extract the cell connectivity information
    point_data = data_set.GetPointData()  # Access point-wise data

    n_points = grid.GetNumberOfPoints() # Number of points and cells
    n_cells = grid.GetNumberOfCells()

    end = time.time()
    print("Reading files takes %.2e s" % (end - start))
    print(f"\tNumber of points: {n_points}")
    print(f"\tNumber of cells: {n_cells}")
    print("\tAvailable point data fields:")
    for i in range(point_data.GetNumberOfArrays()):
        # Field names in point data
        name = point_data.GetArrayName(i)
        print(f"\t  - {name}")
    start = end

    # Point coordinates as numpy; only x and y are kept (2D model)
    vtk_points = grid.GetPoints().GetData()
    points_np = vtk_to_numpy(vtk_points)  # Shape: (n_points, 3)
    points_2d = points_np[:, :2]

    # Triangulate once and share it: passing raw points would make every
    # LinearNDInterpolator re-triangulate the same point set.
    triangulation = Delaunay(points_2d)

    # One linear interpolator per point-data field, keyed by field name
    interpolators = {}
    for i in range(point_data.GetNumberOfArrays()):
        array_name = point_data.GetArrayName(i)
        vtk_array = point_data.GetArray(i)

        if vtk_array is None:
            print(f"Warning: Array {array_name} is None, skipping.")
            continue

        interpolators[array_name] = LinearNDInterpolator(triangulation, vtk_to_numpy(vtk_array), fill_value=np.nan)

    end = time.time()
    print("Construct linear ND interpolator takes %.2e s" % (end - start))
    start = end

    # Mesh resolution from hamageolib.utils.vtk_utilities.calculate_resolution,
    # interpolated on the same triangulation
    resolution_np = calculate_resolution(grid)
    interpolators["resolution"] = LinearNDInterpolator(triangulation, resolution_np)

    end = time.time()
    print("Calculating resolution takes %.2e s" % (end - start))
    start = end

##### Create grid to plot

In [None]:
if create_and_run_ggrowth_case and create_and_run_ggrowth_case_plot_one_step:

    start = time.time()

    # Grid spacing (m)
    interval = 0.1e3

    # Bounding box of the 2D mesh points
    x_min, y_min = np.min(points_2d, axis=0)
    x_max, y_max = np.max(points_2d, axis=0)

    # Regular grid inside the bounding box. The x and y spacings differ slightly
    # so the two grid dimensions are unequal: an accidental x/y transpose then
    # shows up as a shape mismatch instead of passing silently.
    xs = np.arange(x_min, x_max, interval*0.99)
    ys = np.arange(y_min, y_max, interval*1.01)
    x_grid, y_grid = np.meshgrid(xs, ys, indexing="ij")  # shape (len(xs), len(ys))

    # Flattened (x, y) query points, one row per grid node
    grid_points_2d = np.column_stack([x_grid.ravel(), y_grid.ravel()])

    def _interpolate_to_grid(field_name):
        """Evaluate the linear interpolator of ``field_name`` on the regular grid.

        Scalar fields are returned with shape ``x_grid.shape``; vector fields keep
        their trailing component axis, i.e. ``x_grid.shape + (n_components,)``.
        """
        values = interpolators[field_name](grid_points_2d)
        return values.reshape(x_grid.shape + values.shape[1:])

    T_grid = _interpolate_to_grid("T")  # temperature
    P_grid = _interpolate_to_grid("p")  # pressure
    grain_size_grid = _interpolate_to_grid("grain_size")
    resolutions_grid = _interpolate_to_grid("resolution")

    # Velocity is a vector field; keep the x and y components
    velocity_grid = _interpolate_to_grid("velocity")
    vx_grid = velocity_grid[:, :, 0]
    vy_grid = velocity_grid[:, :, 1]

    end = time.time()
    print("Interpolating to regular grid takes %.2e s" % (end - start))
    print("\tgrid shape: (x axis, y axis): ", x_grid.shape)
    start = end

##### Generate plots

In [None]:
if create_and_run_ggrowth_case and create_and_run_ggrowth_case_plot_one_step:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    from matplotlib import gridspec
    from cmcrameri import cm as ccm

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Plot settings.
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (0.0, 100e3)
    x_tick_interval = 25e3   # tick interval along x
    y_lim = (0.0, 100e3)
    y_tick_interval = 25e3  # tick interval along y

    resolution_lim = (0.0, 10e3) # resolution
    resolution_level = 50  # number of levels in contourf plot
    resolution_tick_interval = 2.5e3  # colorbar tick interval

    T_lim = (0.0, 2000) # T
    T_level = 50  # number of levels in contourf plot
    T_tick_interval = 500 # colorbar tick interval

    P_lim = (-1e8, 1e8) # P
    P_level = 50  # number of levels in contourf plot
    P_tick_interval = 5e7  # colorbar tick interval

    grain_size_min = -4.5  # log10 of grain size
    grain_size_max = -4
    grain_size_tick_interval = 0.1

    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    def _inclusive_ticks(lower, upper, interval):
        """Return ticks from ``lower`` to ``upper``, both ends included, spaced by ``interval``.

        np.arange excludes its stop value, so half an interval is added to keep
        the upper limit as a tick despite floating-point round-off.
        """
        return np.arange(lower, upper + 0.5 * interval, interval)

    def _format_map_axis(ax):
        """Apply the layout shared by all map panels: aspect, limits, ticks, labels, spines."""
        ax.set_aspect("equal", adjustable="box")  # Equal aspect ratio

        ax.set_xlim(x_lim)
        ax.set_ylim(y_lim)

        ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
        ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
        ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
        ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

        ax.set_xlabel("X")
        ax.set_ylabel("Y")

        for spine in ax.spines.values():
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # Create a figure with a 2x2 grid layout
    fig = plt.figure(figsize=(12, 10), tight_layout=True)
    gs = gridspec.GridSpec(2, 2)

    # Mesh resolution
    ax = fig.add_subplot(gs[0, 0])
    levels = np.linspace(resolution_lim[0], resolution_lim[1], resolution_level)
    ticks = _inclusive_ticks(resolution_lim[0], resolution_lim[1], resolution_tick_interval)
    color_map = ax.contourf(x_grid, y_grid, resolutions_grid, vmin=resolution_lim[0], vmax=resolution_lim[1], levels=levels, cmap="plasma_r")
    cbar = fig.colorbar(color_map, ax=ax, label="Resolution")
    cbar.set_ticks(ticks)
    _format_map_axis(ax)

    # Temperature, overlaid with velocity arrows
    ax = fig.add_subplot(gs[0, 1])
    levels = np.linspace(T_lim[0], T_lim[1], T_level)
    ticks = _inclusive_ticks(T_lim[0], T_lim[1], T_tick_interval)
    color_map = ax.contourf(x_grid, y_grid, T_grid, vmin=T_lim[0], vmax=T_lim[1], levels=levels, cmap=ccm.lapaz)
    cbar = fig.colorbar(color_map, ax=ax, label="T")
    cbar.set_ticks(ticks)

    step = 200  # draw an arrow at every 200th grid node along each axis
    qv = ax.quiver(x_grid[::step, ::step], y_grid[::step, ::step],\
                vx_grid[::step, ::step], vy_grid[::step, ::step],\
                    scale=1, width=0.004, color='black')
    # Reference arrow of length U=0.1 in velocity units; the '10 cm/yr' label
    # assumes velocity is output in m/yr — NOTE(review): confirm.
    ax.quiverkey(qv, X=0.85, Y=-0.1, U=0.1, label='10 cm/yr', labelpos='E')
    _format_map_axis(ax)

    # Pressure
    ax = fig.add_subplot(gs[1, 0])
    levels = np.linspace(P_lim[0], P_lim[1], P_level)
    ticks = _inclusive_ticks(P_lim[0], P_lim[1], P_tick_interval)
    color_map = ax.contourf(x_grid, y_grid, P_grid, vmin=P_lim[0], vmax=P_lim[1], levels=levels, cmap=ccm.tokyo_r)
    cbar = fig.colorbar(color_map, ax=ax, label="P")
    cbar.set_ticks(ticks)
    _format_map_axis(ax)

    # Grain size (log10)
    ax = fig.add_subplot(gs[1, 1])
    levels = np.linspace(grain_size_min, grain_size_max, 50)
    ticks = _inclusive_ticks(grain_size_min, grain_size_max, grain_size_tick_interval)
    color_map = ax.contourf(x_grid, y_grid, np.log10(grain_size_grid), vmin=grain_size_min, vmax=grain_size_max, levels=levels, cmap=ccm.tokyo)
    cbar = fig.colorbar(color_map, ax=ax, label="grain size")
    cbar.set_ticks(ticks)
    _format_map_axis(ax)

    # save output
    fig_path = os.path.join(img_dir, "visualization_%.2fMa" % (plot_time/1e6))
    fig.savefig(fig_path + ".png")
    print("Saved figure %s" % (fig_path + ".png"))
    fig.savefig(fig_path + ".pdf")
    print("Saved figure %s" % (fig_path + ".pdf"))

    # Reset rcParams to defaults
    rcdefaults()

#### Loop over the timesteps

##### Create the interpolators

In [None]:
create_and_run_ggrowth_case_plot_loop_steps = True

if create_and_run_ggrowth_case and create_and_run_ggrowth_case_plot_loop_steps:

    import vtk
    from vtk.util.numpy_support import vtk_to_numpy
    from hamageolib.utils.vtk_utilities import calculate_resolution
    import time
    from scipy.interpolate import LinearNDInterpolator
    from scipy.spatial import Delaunay

    plot_time_interval = 1e4 # yr
    # NOTE(review): assumes plot_time_interval is a multiple of visualization_timestep,
    # otherwise the nearest output file is used for each plot time.

    # The tiny factor on end_time keeps end_time itself inside the half-open arange
    plot_times = np.arange(0.0, end_time*1.0000001, plot_time_interval)

    # One dict of field-name -> LinearNDInterpolator per plot time
    interpolator_array = []

    start = time.time()

    for i_step, plot_time in enumerate(plot_times):

        # round() rather than int(): int() truncates a float quotient like 2.9999999999999996
        vtu_timestep = int(round(plot_time / visualization_timestep))

        pvtu_file = os.path.join(case_dir, "output", "solution", "solution-%05d.pvtu" % vtu_timestep)
        if not os.path.isfile(pvtu_file):
            raise FileNotFoundError("Missing visualization output %s" % pvtu_file)

        # Read the pvtu file
        reader = vtk.vtkXMLPUnstructuredGridReader()
        reader.SetFileName(pvtu_file)
        reader.Update()

        # Get the output data from the reader
        grid = reader.GetOutput()  # Access the unstructured grid
        data_set = reader.GetOutputAsDataSet()  # Access the dataset representation
        points = grid.GetPoints()  # Extract the points (coordinates)
        cells = grid.GetCells()  # Extract the cell connectivity information
        point_data = data_set.GetPointData()  # Access point-wise data

        n_points = grid.GetNumberOfPoints() # Number of points and cells
        n_cells = grid.GetNumberOfCells()

        # Describe the file contents once, for the first step only.
        # Distinct loop names (i_step / i_array) so the step index is never shadowed.
        if i_step == 0:
            print("Data in file:")
            print(f"\tNumber of points: {n_points}")
            print(f"\tNumber of cells: {n_cells}")
            print("\tAvailable point data fields:")
            for i_array in range(point_data.GetNumberOfArrays()):
                # Field names in point data
                name = point_data.GetArrayName(i_array)
                print(f"\t  - {name}")

        # Point coordinates as numpy; only x and y are kept (2D model)
        vtk_points = grid.GetPoints().GetData()
        points_np = vtk_to_numpy(vtk_points)  # Shape: (n_points, 3)
        points_2d = points_np[:, :2]

        # Triangulate this step's mesh once and share it across all fields
        triangulation = Delaunay(points_2d)

        interpolators = {}
        for i_array in range(point_data.GetNumberOfArrays()):
            array_name = point_data.GetArrayName(i_array)
            vtk_array = point_data.GetArray(i_array)

            if vtk_array is None:
                print(f"Warning: Array {array_name} is None, skipping.")
                continue

            interpolators[array_name] = LinearNDInterpolator(triangulation, vtk_to_numpy(vtk_array), fill_value=np.nan)

        # Mesh resolution from hamageolib.utils.vtk_utilities.calculate_resolution
        resolution_np = calculate_resolution(grid)
        interpolators["resolution"] = LinearNDInterpolator(triangulation, resolution_np)

        interpolator_array.append(interpolators)

    end = time.time()
    print("Reading files takes %.2e s" % (end - start))

In [None]:
if create_and_run_ggrowth_case and create_and_run_ggrowth_case_plot_loop_steps:

    # Follow a point moving horizontally at Vx_m_yr from x = 0 along y = const,
    # recording the interpolated grain size it sees at each plot time.
    grain_sizes = []
    y = 5e4 # m
    for i, plot_time in enumerate(plot_times):
        interpolators = interpolator_array[i]
        x = plot_time * Vx_m_yr

        # Dedicated name: rebinding the notebook-global points_2d (the mesh
        # coordinates) would break re-running the grid-interpolation cells.
        query_point = np.array([x, y])
        grain_size = interpolators["grain_size"](query_point)

        # A single query point yields a size-1 array; store it as a plain scalar
        # (nan when the point falls outside the mesh).
        grain_sizes.append(np.asarray(grain_size).item())

In [None]:
if create_and_run_ggrowth_case and create_and_run_ggrowth_case_plot_loop_steps:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Plot settings.
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines

    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    fig, ax = plt.subplots()

    # Time axis spans the whole simulation with five major ticks
    x_lim = [0.0, end_time]
    x_tick_interval = end_time / 4.0
    y_tick_interval = 0.1

    # Flatten so both scalar and size-1-array samples give a 1D series
    log_grain_sizes = np.log10(np.asarray(grain_sizes, dtype=float)).ravel()

    # Samples can be nan when the tracked point leaves the mesh (e.g. exactly on
    # the domain edge at end_time); derive the y limits from finite values only.
    finite = np.isfinite(log_grain_sizes)
    if not np.any(finite):
        raise ValueError("No finite grain size sampled along the trajectory; check that the query points lie inside the mesh")
    y_min = np.floor(np.min(log_grain_sizes[finite])/0.1) * 0.1
    y_max = np.ceil(np.max(log_grain_sizes[finite])/0.1) * 0.1
    if y_max <= y_min:
        # constant grain size: widen by one tick so the axis is not degenerate
        y_max = y_min + y_tick_interval
    y_lim = [y_min, y_max]

    ax.plot(plot_times, log_grain_sizes)
    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax.grid()

    ax.set_xlabel("Time (yr)")
    ax.set_ylabel("log10(Grain Size) (m)")

    for spine in ax.spines.values():
        # Adjust spine thickness for this plot
        spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    fig_path = os.path.join(img_dir, "time_analysis")
    fig.savefig(fig_path + ".png")
    print("Saved figure %s" % (fig_path + ".png"))
    fig.savefig(fig_path + ".pdf")
    print("Saved figure %s" % (fig_path + ".pdf"))

    # Reset rcParams to defaults
    rcdefaults()

## Analyze the synthetic equation

### Initialize

In [None]:
analyze_grain_growth_synthetic = True
if analyze_grain_growth_synthetic:
    from hamageolib.core.GrainSize import GrainGrowthModel, GrainGrowthParams

    use_kinetics = "default"

    # Grain-growth kinetic presets, selected by the name in use_kinetics
    kinetics_presets = {
        "wadleyite": {
            "grain_growth_rate_constant": 3.02e-4,
            "m": 3,
            "grain_growth_activation_energy": 6.62e5,
            "grain_growth_activation_volume": 0.0,
        },
        "default": {
            "grain_growth_rate_constant": 4e-45,
            "m": 10,
            "grain_growth_activation_energy": 3e5,
            "grain_growth_activation_volume": 0.0,
        },
    }
    if use_kinetics not in kinetics_presets:
        raise NotImplementedError
    params = GrainGrowthParams(**kinetics_presets[use_kinetics])

    # todo_gz
    gModel = GrainGrowthModel(params=params)

    # Output folder for figures of this analysis
    img_dir = os.path.join(results_dir, "analyze_grain_growth_synthetic")
    if not os.path.isdir(img_dir):
        os.mkdir(img_dir)

    year = 365 * 24 * 3600.0  # seconds per year


### Solve analytic relation for grain growth vs time

In [None]:
if analyze_grain_growth_synthetic:

    P = 14e9  # Pa
    T = 1600  # K, default-1600.0
    initial_grain_size = 8e-5  # plotted below as log10 of grain size in m
    t_max = 1e6 * year  # end time of the evolution, in seconds

    # Sample [0, t_max] uniformly. The end point is derived from t_max so the
    # sampled range and the plot limit (t_max/year) always agree.
    ts = np.linspace(0.0, t_max, 1000)

    # Grain size at every sampled time, at fixed P and T
    grain_sizes = gModel.grain_size_at_time(initial_grain_size, ts, P, T)

In [None]:
if analyze_grain_growth_synthetic:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator

    # Colors of the active property cycle
    default_colors = [entry['color'] for entry in plt.rcParams['axes.prop_cycle']]

    # Plot-scaling knobs passed to plot_helper.scale_matplotlib_params.
    # Rules of thumb: round axis limits (5, 10, 50, 100, ...) and about five major ticks per axis.
    scaling_factor = 1.0
    font_scaling_multiplier = 1.5
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0
    n_minor_ticks = 4  # minor ticks between two major ticks

    plot_helper.scale_matplotlib_params(
        scaling_factor,
        font_scaling_multiplier=font_scaling_multiplier,
        legend_font_scaling_multiplier=legend_font_scaling_multiplier,
        line_width_scaling_multiplier=line_width_scaling_multiplier,
    )

    # TrueType font embedding keeps the text editable in tools such as Illustrator
    plt.rcParams.update({'font.family': 'Times New Roman', 'pdf.fonttype': 42, 'ps.fonttype': 42})

    log_grain_sizes = np.log10(grain_sizes)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(ts/year, log_grain_sizes)

    # Time (yr) on x; log10 grain size on y, padded outward to multiples of 0.1
    x_lim = [-5e4, t_max/year]
    x_tick_interval = 5e5
    y_min = np.floor(np.min(log_grain_sizes)/0.1) * 0.1
    y_max = np.ceil(np.max(log_grain_sizes)/0.1) * 0.1
    y_lim = [y_min, y_max]
    y_tick_interval = 0.1
    n_minor_ticks = 4

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    for axis, major in ((ax.xaxis, x_tick_interval), (ax.yaxis, y_tick_interval)):
        axis.set_major_locator(MultipleLocator(major))
        axis.set_minor_locator(MultipleLocator(major/(n_minor_ticks+1)))

    ax.grid()

    ax.set_xlabel("Time (yr)")
    ax.set_ylabel("log10(Grain Size) (m)")

    spine_width = 0.5 * scaling_factor * line_width_scaling_multiplier
    for spine in ax.spines.values():
        spine.set_linewidth(spine_width)

    # Write the same figure as PNG and PDF
    fig_path = os.path.join(img_dir, "time_analysis_%s_ig%.2e_T%.2f" % (use_kinetics, initial_grain_size, T))
    for suffix in (".png", ".pdf"):
        fig.savefig(fig_path + suffix)
        print("Saved figure %s" % (fig_path + suffix))

    # Restore matplotlib defaults so later cells start from a clean state
    rcdefaults()

### Solving the grain size at constant time with different T

In [None]:
if analyze_grain_growth_synthetic:

    P = 14e9  # Pa
    initial_grain_size = 1e-8
    # Fixed evolution times (s): 1e4, 1e5, 1e6 and 1e7 years
    t_max_array = [t_yr * year for t_yr in (1e4, 1e5, 1e6, 1e7)]

    # Temperature sweep (K)
    Ts = np.linspace(1000.0, 2000.0, 1000)

    # One grain-size curve over Ts for every evolution time
    grain_size_array = []
    for i, t_max in enumerate(t_max_array):
        grain_size_array.append(gModel.grain_size_at_time(initial_grain_size, t_max, P, Ts))

In [None]:
if analyze_grain_growth_synthetic:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Rules of thumb:
    # 1. Set the limits to round values such as 5.0, 10.0 or 50.0, 100.0
    # 2. Aim for about five major ticks per axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5  # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0  # extra scaling multiplier for lines

    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    fig, ax = plt.subplots(figsize=(10, 6))

    # Plot one curve per evolution time while tracking the global min/max of
    # log10(grain size) over ALL curves, so the y-range covers every curve.
    min_log_grain_sizes = None; max_log_grain_sizes = None
    for i, t_max in enumerate(t_max_array):
        grain_sizes = grain_size_array[i]
        log_grain_sizes = np.log10(grain_sizes)

        if i == 0:
            min_log_grain_sizes = np.min(log_grain_sizes)
            max_log_grain_sizes = np.max(log_grain_sizes)
        else:
            min_log_grain_sizes = min(min_log_grain_sizes, np.min(log_grain_sizes))
            max_log_grain_sizes = max(max_log_grain_sizes, np.max(log_grain_sizes))

        # Cycle the colors so more curves than colors does not raise IndexError
        ax.plot(Ts, log_grain_sizes, color=default_colors[i % len(default_colors)], label="%.2e yr" % (t_max/year))

    x_lim = [np.min(Ts), np.max(Ts)]
    x_tick_interval = 500.0
    y_tick_interval = 0.5
    # Pad the y-range outward to multiples of the major tick interval, using the
    # extrema of all curves (not just the last one plotted).
    y_min = np.floor(min_log_grain_sizes / y_tick_interval) * y_tick_interval
    y_max = np.ceil(max_log_grain_sizes / y_tick_interval) * y_tick_interval
    if y_max <= y_min:
        # Degenerate case (all values on one tick): keep a non-empty range
        y_max = y_min + y_tick_interval
    y_lim = [y_min, y_max]
    n_minor_ticks = 4

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    ax.grid()
    ax.legend()

    ax.set_xlabel("T (K)")
    ax.set_ylabel("log10(Grain Size) (m)")

    for spine in ax.spines.values():
        # Adjust spine thickness for this plot
        spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    fig_path = os.path.join(img_dir, "T_analysis_%s_ig%.2e" % (use_kinetics, initial_grain_size))
    fig.savefig(fig_path + ".png")
    print("Saved figure %s" % (fig_path + ".png"))
    fig.savefig(fig_path + ".pdf")
    print("Saved figure %s" % (fig_path + ".pdf"))

    # Reset rcParams to defaults
    rcdefaults()

# ASPECT implementation

## Derive the reference pressure profile

# Run ASPECT tests

(Keep this section consistent with the "ASPECT Implementation" section of the supplementary material.)


## Initial condition test (TwoDSubduction_metastable_initial.prm)

This test checks the initial condition of the "metastable" composition.

In [None]:
is_run_aspect_tests_initial = False

if is_run_aspect_tests_initial:

    # Composition placed by the initial condition: "background" or "spharz"
    test_composition = "spharz"
    # Mesh repetitions in x and y: 50 is the original test, 200 gives plot-quality output
    n_repetition = 50

    # Local ASPECT build and the test parameter file used as the template
    aspect_dir = "/home/lochy/Softwares/aspect"
    aspect_executable = os.path.join(aspect_dir, "build_master_TwoD_rebase/aspect")
    prm_template_path = os.path.join(aspect_dir, "tests", "TwoDSubduction_metastable_initial.prm")

    for required_file in (aspect_executable, prm_template_path):
        assert os.path.isfile(required_file)

    # Scratch directory under the repository (not referenced in this cell)
    case_root_dir = os.path.join(root_path, "dtemp") 

    # Directory in which the case is created and run.
    # NOTE(review): "/mnt/lochz" differs from the "/mnt/lochy" mount used by the 1ky test; confirm it is intended.
    case_dir = os.path.join("/mnt/lochz/ASPECT_DATA/TwoDSubduction/test_cases", "TwoDSubduction_metastable_initial")
    if not os.path.isdir(case_dir):
        os.mkdir(case_dir)

    # Output directory name, tagged with the mesh repetitions
    output_dirname = "output_initial_%d" % n_repetition

In [None]:
if is_run_aspect_tests_initial:

    from hamageolib.utils.dealii_param_parser import parse_parameters_to_dict, save_parameters_from_dict
    from hamageolib.utils.world_builder_file_parser import find_feature_by_name, update_or_add_feature

    # Parse the template prm into a nested dict so single entries can be overridden
    with open(prm_template_path, 'r') as template_file:
        params_dict = parse_parameters_to_dict(template_file)

    # Redirect the output and point to the locally built plugin libraries
    params_dict["Output directory"] = os.path.join(case_dir, output_dirname)
    params_dict["Additional shared libraries"] = "$ASPECT_SOURCE_DIR/build_master_TwoD_rebase/subduction_temperature2d/libsubduction_temperature2d.so, $ASPECT_SOURCE_DIR/build_master_TwoD_rebase/prescribe_field/libprescribed_temperature.so, $ASPECT_SOURCE_DIR/build_master_TwoD_rebase/visco_plastic_TwoD/libvisco_plastic_TwoD.so"

    # Mesh resolution: same number of repetitions along both box axes
    box_geometry = params_dict["Geometry model"]["Box"]
    box_geometry["X repetitions"] = str(n_repetition)
    box_geometry["Y repetitions"] = str(n_repetition)

    # "background" keeps the template's composition; "spharz" sets the second field to 1 everywhere
    if test_composition not in ("background", "spharz"):
        raise ValueError("test_composition must be background or spharz")
    if test_composition == "spharz":
        params_dict["Initial composition model"]["Function"]["Function expression"] = "0.0 ; 1.0; 0.0 ; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0"

    params_dict["Postprocess"]["Visualization"]["Output format"] = "vtu"

    # Write the modified parameters as case.prm in the case directory
    prm_path = os.path.join(case_dir, "case.prm")
    with open(prm_path, 'w') as prm_file:
        save_parameters_from_dict(prm_file, params_dict)
    assert os.path.isfile(prm_path)

    print("Created case in %s" % (case_dir))

### Run

Use "subprocess.run" to run the case.

Capture the standard output and error streams

Check that
  * the expected line reporting the total wallclock time appears in the output;
  * nothing is written to stderr.

In [None]:
if is_run_aspect_tests_initial:

    # Run the ASPECT executable with the parameter file.
    # 'capture_output=True' collects stdout and stderr; 'text=True' decodes them to str.
    completed_process = subprocess.run([aspect_executable, prm_path], capture_output=True, text=True)

    # Capture the standard output and error streams
    stdout = completed_process.stdout
    stderr = completed_process.stderr

    # Uncomment the following lines for debugging purposes to inspect the output
    # print(stdout)  # Debugging: Prints the standard output
    # print(stderr)  # Debugging: Prints the standard error

    # A non-zero exit status means the run aborted
    assert completed_process.returncode == 0, \
        "ASPECT exited with status %d:\n%s" % (completed_process.returncode, stderr)

    # A finished run reports its timing, e.g.
    #   -- Total wallclock time elapsed including restarts: 1s
    # Search the whole log rather than one fixed line index. A fixed index raises
    # IndexError on a short (aborted) log and breaks if the footer layout changes.
    assert re.search(r"Total wallclock", stdout), \
        "'Total wallclock' not found in ASPECT stdout; the run may not have finished"

    # Ensure that the error stream is empty, indicating no issues during the run
    assert stderr == "", "ASPECT wrote to stderr:\n%s" % stderr

## Post-process

In [None]:
if is_run_aspect_tests_initial:
    # Load the solution-<vtu_step>.pvtu output of the initial test, report its size and
    # point-data fields, and build one LinearNDInterpolator per point-data field
    # (keyed by field name) plus one for the mesh resolution.

    vtu_step = 0  # * 1ky

    import vtk
    from vtk.util.numpy_support import vtk_to_numpy
    from hamageolib.utils.vtk_utilities import calculate_resolution
    import time
    from scipy.interpolate import LinearNDInterpolator

    pvtu_file = os.path.join(case_dir, output_dirname, "solution", "solution-%05d.pvtu" % vtu_step)
    assert(os.path.isfile(pvtu_file))

    # Read the pvtu file
    start = time.time()

    reader = vtk.vtkXMLPUnstructuredGridReader()
    reader.SetFileName(pvtu_file)
    reader.Update()

    end = time.time()
    print("Initiating reader takes %.2e s" % (end - start))
    start = end

    # Get the output data from the reader
    grid = reader.GetOutput()  # Access the unstructured grid
    data_set = reader.GetOutputAsDataSet()  # Access the dataset representation
    points = grid.GetPoints()  # Extract the points (coordinates)
    cells = grid.GetCells()  # Extract the cell connectivity information
    point_data = data_set.GetPointData()  # Access point-wise data

    n_points = grid.GetNumberOfPoints() # Number of points and cells
    n_cells = grid.GetNumberOfCells()

    end = time.time()
    print("Reading files takes %.2e s" % (end - start))
    print(f"\tNumber of points: {n_points}")
    print(f"\tNumber of cells: {n_cells}")
    print("\tAvailable point data fields:")
    for i in range(point_data.GetNumberOfArrays()):
        # Field names in point data
        name = point_data.GetArrayName(i)
        print(f"\t  - {name}")
    start = end

    # Convert data to numpy arrays:
    # point coordinates here, every point-data field in the loop below
    vtk_points = grid.GetPoints().GetData()
    points_np = vtk_to_numpy(vtk_points)  # Shape: (n_points, 3)
    points_2d = points_np[:, :2]  # Use only the first two columns for 2D coordinates

    # Initialize dictionary for interpolators (field name -> interpolator)
    interpolators = {}

    # Loop over all arrays in point data
    num_arrays = point_data.GetNumberOfArrays()
    for i in range(num_arrays):
        array_name = point_data.GetArrayName(i)
        vtk_array = point_data.GetArray(i)
        
        if vtk_array is None:
            print(f"Warning: Array {array_name} is None, skipping.")
            continue
        
        # Convert VTK array to NumPy
        np_array = vtk_to_numpy(vtk_array)
        
        # Create interpolator and add to dict; queries outside the mesh return NaN
        interpolators[array_name] = LinearNDInterpolator(points_2d, np_array, fill_value=np.nan)

    # Calculate the mesh resolution; calculate_resolution is imported from
    # hamageolib.utils.vtk_utilities above. Assumed to return one value per point,
    # since it is paired with points_2d below — confirm.
    resolution_np = calculate_resolution(grid)

    end = time.time()
    print("Calculating resolution takes %.2e s" % (end - start))
    start = end

    # Add an interpolator for the resolution (the point-data fields, e.g. T and p,
    # already got theirs in the loop above)
    interpolators["resolution"] = LinearNDInterpolator(points_2d, resolution_np)  # Interpolator for resolution

    end = time.time()
    print("Construct linear ND interpolator takes %.2e s" % (end - start))
    start = end

In [None]:
if is_run_aspect_tests_initial:
    
    start = time.time()

    # Nominal spacing of the regular sampling grid (in meters)
    interval = 1000.0

    # Bounding box of the mesh points
    x_min, y_min = np.min(points_2d, axis=0)
    x_max, y_max = np.max(points_2d, axis=0)

    # Slightly different spacings in x and y make the two grid dimensions
    # differ in length, so an accidental axis swap shows up as a shape error.
    xs = np.arange(x_min, x_max, interval*0.99)
    ys = np.arange(y_min, y_max, interval*1.01)
    # indexing="ij": the first array axis follows x, the second follows y
    x_grid, y_grid = np.meshgrid(xs, ys, indexing="ij")

    # Query points as (N, 2) rows of (x, y)
    grid_points_2d = np.vstack([x_grid.ravel(), y_grid.ravel()]).T

    def _sample_on_grid(field_name):
        """Evaluate the linear interpolator of ``field_name`` on the regular grid, shaped like x_grid."""
        sampled = interpolators[field_name](grid_points_2d)
        return sampled.reshape(x_grid.shape)

    T_grid = _sample_on_grid("T")  # temperature
    P_grid = _sample_on_grid("p")  # pressure
    resolutions_grid = _sample_on_grid("resolution")
    density_grid = _sample_on_grid("density")
    metastable_grid = _sample_on_grid("metastable")  # metastable composition
        
    end = time.time()
    print("Interpolating to regular grid takes %.2e s" % (end - start))
    print("\tgrid shape: (x axis, y axis): ", x_grid.shape)
    start = end

In [None]:
if is_run_aspect_tests_initial:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    from matplotlib import gridspec
    from cmcrameri import cm as ccm

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Rules of thumb:
    # 1. Set the limits to round values such as 5.0, 10.0 or 50.0, 100.0
    # 2. Aim for about five major ticks per axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5  # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0  # extra scaling multiplier for lines
    x_lim = (0.0, 800)  # km
    x_tick_interval = 100  # tick interval along x
    y_lim = (0.0, 800)  # km
    y_tick_interval = 100  # tick interval along y

    resolution_lim = (0.0, 100e3)  # resolution
    resolution_level = 50  # number of levels in contourf plot
    resolution_tick_interval = 25e3  # colorbar tick interval

    T_lim = (0.0, 2000.0)  # T
    T_level = 50  # number of levels in contourf plot
    T_tick_interval = 250.0  # colorbar tick interval

    # The interpolators fill NaN outside the mesh, so use NaN-aware extrema;
    # plain np.min/np.max would return NaN and break the contour levels.
    P_lim = (np.nanmin(P_grid), np.nanmax(P_grid))  # P
    P_level = 50  # number of levels in contourf plot
    P_tick_interval = 1e9  # colorbar tick interval

    density_lim = (3000.0, 4000.0)
    density_level = 50  # number of levels in contourf plot
    density_tick_interval = 100.0  # colorbar tick interval

    metastable_lim = (-0.01, 1.01)
    metastable_level = 50
    metastable_tick_interval = 0.25

    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    def _aligned_ticks(lim, interval):
        """Return the multiples of ``interval`` inside the closed range ``lim`` (both ends included).

        Starting at a multiple of the interval, rather than at lim[0], avoids
        ticks such as -0.01, 0.24, ... and keeps the upper limit ticked.
        """
        first = np.ceil(lim[0] / interval) * interval
        last = np.floor(lim[1] / interval) * interval
        n_ticks = int(round((last - first) / interval)) + 1
        return first + interval * np.arange(max(n_ticks, 0))

    def _plot_field_panel(ax, field_grid, lim, n_levels, tick_interval, cmap, label):
        """Draw ``field_grid`` as filled contours (x, y in km) on ``ax`` with a colorbar and the shared axis styling."""
        levels = np.linspace(lim[0], lim[1], n_levels)
        color_map = ax.contourf(x_grid/1e3, y_grid/1e3, field_grid, vmin=lim[0], vmax=lim[1], levels=levels, cmap=cmap)
        cbar = ax.figure.colorbar(color_map, ax=ax, label=label)
        cbar.set_ticks(_aligned_ticks(lim, tick_interval))

        ax.set_aspect("equal", adjustable="box")  # Equal aspect ratio on every panel

        ax.set_xlim(x_lim)
        ax.set_ylim(y_lim)

        ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
        ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
        ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
        ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

        ax.set_xlabel("X")
        ax.set_ylabel("Y")

        for spine in ax.spines.values():
            # Adjust spine thickness for this plot
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # Create a figure with a 3x2 grid layout (five panels used)
    fig = plt.figure(figsize=(12, 15), tight_layout=True)
    gs = gridspec.GridSpec(3, 2)

    # (grid slot, field, limits, contour levels, colorbar tick interval, colormap, colorbar label)
    panel_specs = [
        (gs[0, 0], resolutions_grid, resolution_lim, resolution_level, resolution_tick_interval, "plasma_r", "Resolution"),
        (gs[0, 1], T_grid, T_lim, T_level, T_tick_interval, ccm.lapaz, "T"),
        (gs[1, 0], P_grid, P_lim, P_level, P_tick_interval, ccm.tokyo_r, "P"),
        (gs[1, 1], density_grid, density_lim, density_level, density_tick_interval, ccm.batlow, "density"),
        (gs[2, 0], metastable_grid, metastable_lim, metastable_level, metastable_tick_interval, "viridis", "metastable"),
    ]
    for subplot_spec, field_grid, lim, n_levels, tick_interval, cmap, label in panel_specs:
        ax = fig.add_subplot(subplot_spec)
        _plot_field_panel(ax, field_grid, lim, n_levels, tick_interval, cmap, label)

    # save figure as PNG and PDF
    ofile_base = os.path.join(case_dir, "metastable_initial_nrep_%d" % (n_repetition))
    for suffix in (".png", ".pdf"):
        ofile = ofile_base + suffix
        fig.savefig(ofile)
        print("Saved figure %s" % ofile)

    # Reset rcParams to defaults
    rcdefaults()


## Diagram test (TwoDSubduction_metastable_reaction_1ky)

This test makes use of

    TwoDSubduction_metastable_reaction_1ky.prm

or, when `with_grain_size_evolution = True`, of

    TwoDSubduction_metastable_reaction_grain_size.prm

### Set up

Options

* Change the resolution. The test is set up with a low resolution, whereas a high-resolution result is needed here to produce results comparable to those of the Python code.

* This setup both runs the case and plots the results. To only plot existing results, set:

    is_run_aspect_tests_1ky_solving = False

* To measure the run time of the metastable part, toggle this option and compare the run times:

    with_metastable = True

* To test the code path of the metastable function while skipping the actual computation, toggle this option:

    reaction_metastable_trivial = False

In [None]:
is_run_aspect_tests_1ky = False
is_run_aspect_tests_1ky_solving = True

if is_run_aspect_tests_1ky:

    # Feature switches for this test
    with_metastable = True
    with_grain_size_evolution = True
    reaction_metastable_trivial = False
    test_composition = "spharz"  # background or spharz
    n_repetition = 200  # 50 - original, 200 - high to make plots

    aspect_dir = "/home/lochy/Softwares/aspect"
    aspect_executable = os.path.join(aspect_dir, "build_master_TwoD_rebase/aspect")

    # The template prm and the case directory share the test name, which depends
    # on whether grain-size evolution is enabled.
    test_name = ("TwoDSubduction_metastable_reaction_grain_size" if with_grain_size_evolution
                 else "TwoDSubduction_metastable_reaction_1ky")
    prm_template_path = os.path.join(aspect_dir, "tests", test_name + ".prm")
    case_dir = os.path.join("/mnt/lochy/ASPECT_DATA/MOW/mow_tests2", test_name)  # New directory to run the case

    assert os.path.isfile(aspect_executable)
    assert os.path.isfile(prm_template_path)

    # Scratch directory under the repository (not referenced in this cell)
    case_root_dir = os.path.join(root_path, "dtemp") 

    if not os.path.isdir(case_dir):
        os.mkdir(case_dir)

    # The output directory name encodes which metastable variant was run
    if not with_metastable:
        output_dirname = "output_trivial_nrep_%d" % n_repetition
    elif reaction_metastable_trivial:
        output_dirname = "output_reaction_trivial_nrep_%d" % n_repetition
    else:
        output_dirname = "output_nrep_%d" % n_repetition

In [None]:
if is_run_aspect_tests_1ky and is_run_aspect_tests_1ky_solving:
    # Build case.prm for the 1ky reaction test from the template: override the output
    # directory, plugin libraries, mesh repetitions and initial composition, and
    # optionally strip the metastable machinery (with_metastable = False).

    from hamageolib.utils.dealii_param_parser import parse_parameters_to_dict, save_parameters_from_dict
    from hamageolib.utils.world_builder_file_parser import find_feature_by_name, update_or_add_feature


    # Modify the template
    # Also read important parameters like the size of the model

    with open(prm_template_path, 'r') as file:
        params_dict = parse_parameters_to_dict(file)



    params_dict["Output directory"] = os.path.join(case_dir, output_dirname)

    params_dict["Additional shared libraries"] = "$ASPECT_SOURCE_DIR/build_master_TwoD_rebase/subduction_temperature2d/libsubduction_temperature2d.so, $ASPECT_SOURCE_DIR/build_master_TwoD_rebase/prescribe_field/libprescribed_temperature.so, $ASPECT_SOURCE_DIR/build_master_TwoD_rebase/visco_plastic_TwoD/libvisco_plastic_TwoD.so"

    # Same number of mesh repetitions along both box axes
    params_dict["Geometry model"]["Box"]["X repetitions"] = str(n_repetition)
    params_dict["Geometry model"]["Box"]["Y repetitions"] = str(n_repetition)
        
    # "spharz": set the second compositional field to 1 everywhere.
    # The grain-size expression has 12 entries vs 11; presumably that template
    # declares one extra compositional field. Confirm against the prm files.
    if test_composition == "background":
        pass
    elif test_composition == "spharz":
        if with_grain_size_evolution:
            params_dict["Initial composition model"]["Function"]["Function expression"] = "0.0 ; 1.0; 0.0 ; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0"
        else:
            params_dict["Initial composition model"]["Function"]["Function expression"] = "0.0 ; 1.0; 0.0 ; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0"
    else:
        raise ValueError("test_composition must be background or spharz")

    # Without metastability: keep only the four particle-advected compositions,
    # zero their initial values, disable "Reaction metastable", replace the densities
    # ('|'-separated values, presumably one per phase; confirm), drop the
    # "Metastable transition" entry (pop raises KeyError if the template lacks it)
    # and keep only the initial-composition particle property.
    # NOTE: the Densities string continues across lines with backslashes, so the
    # leading whitespace of its continuation lines is part of the value.
    if not with_metastable:
        params_dict["Compositional fields"] =  \
        {
            "Number of fields": "4", "Names of fields": "spcrust, spharz, opcrust, opharz",\
                "Compositional field methods" : "particles, particles, particles, particles",\
                    "Mapped particle properties": "spcrust: initial spcrust, spharz:initial spharz, opharz:initial opharz, opcrust: initial opcrust"
        }

        params_dict["Initial composition model"]["Function"]["Function expression"] = "0.0 ; 0.0; 0.0 ; 0.0"
        
        params_dict["Material model"]["Visco Plastic TwoD"]["Reaction metastable"] = "false"
        params_dict["Material model"]["Visco Plastic TwoD"]["Densities"] = "background: 3300.0|3394.4|3442.1|3453.2|3617.6|3691.5|3774.7|3929.1,\
                        spharz: 3235.0|3372.3|3441.7|3441.7|3680.8|3717.8|3759.4|3836.6,\
                        spcrust: 3000.0|3540.0|3613.0|3871.7,\
                        opcrust: 3000.0, opharz: 3235.0"

        params_dict["Material model"]["Visco Plastic TwoD"].pop("Metastable transition")
        # params_dict["Material model"]["Visco Plastic TwoD"].pop("Metastable transition comp")

        params_dict["Particles"]["List of particle properties"] = "initial composition"

    # Switch on the trivial variant of the metastable reaction (see the markdown options above)
    if reaction_metastable_trivial:
        params_dict["Material model"]["Visco Plastic TwoD"]["Reaction metastable trivial"] = "true"

    params_dict["Postprocess"]["Visualization"]["Output format"] = "vtu"



    # Write to a prm file in the new case directory
    prm_path = os.path.join(case_dir, "case.prm")

    with open(prm_path, 'w') as output_file:
        save_parameters_from_dict(output_file, params_dict)

    assert(os.path.isfile(prm_path))

    print("Created case in %s" % (case_dir))

### Run

Use "subprocess.run" to run the case.

Capture the standard output and error streams

Check that
  * the expected line reporting the total wallclock time appears in the output;
  * nothing is written to stderr.

In [None]:
if is_run_aspect_tests_1ky and is_run_aspect_tests_1ky_solving:

    # Run the ASPECT executable with the parameter file.
    # 'capture_output=True' collects stdout and stderr; 'text=True' decodes them to str.
    completed_process = subprocess.run([aspect_executable, prm_path], capture_output=True, text=True)

    # Capture the standard output and error streams
    stdout = completed_process.stdout
    stderr = completed_process.stderr

    # Uncomment the following lines for debugging purposes to inspect the output
    # print(stdout)  # Debugging: Prints the standard output
    # print(stderr)  # Debugging: Prints the standard error

    # A non-zero exit status means the run aborted
    assert completed_process.returncode == 0, \
        "ASPECT exited with status %d:\n%s" % (completed_process.returncode, stderr)

    # A finished run reports its timing, e.g.
    #   -- Total wallclock time elapsed including restarts: 1s
    # Search the whole log rather than one fixed line index. A fixed index raises
    # IndexError on a short (aborted) log and breaks if the footer layout changes.
    assert re.search(r"Total wallclock", stdout), \
        "'Total wallclock' not found in ASPECT stdout; the run may not have finished"

    # Ensure that the error stream is empty, indicating no issues during the run
    assert stderr == "", "ASPECT wrote to stderr:\n%s" % stderr

In [None]:

if is_run_aspect_tests_1ky and is_run_aspect_tests_1ky_solving:
    # Echo the stderr captured from the ASPECT run (empty for a clean run)
    print(completed_process.stderr)

Save the captured stdout to a separate file.

In [None]:

if is_run_aspect_tests_1ky and is_run_aspect_tests_1ky_solving:

    # Keep the full ASPECT log next to the case for later inspection
    std_file = os.path.join(case_dir, "stdout.txt")
    Path(std_file).write_text(stdout)

    print("Saved stdout outputs: %s" % std_file)

### Post-process

In [None]:
if is_run_aspect_tests_1ky:

    vtu_step = 1  # * 1ky

    import vtk
    from vtk.util.numpy_support import vtk_to_numpy
    from hamageolib.utils.vtk_utilities import calculate_resolution
    import time
    from scipy.interpolate import LinearNDInterpolator
    from scipy.spatial import Delaunay

    pvtu_file = os.path.join(case_dir, output_dirname, "solution", "solution-%05d.pvtu" % vtu_step)
    assert os.path.isfile(pvtu_file), "Missing output file: %s" % pvtu_file

    # Read the pvtu file (reader.Update() triggers the actual read)
    start = time.time()

    reader = vtk.vtkXMLPUnstructuredGridReader()
    reader.SetFileName(pvtu_file)
    reader.Update()

    end = time.time()
    print("Initiating reader takes %.2e s" % (end - start))
    start = end

    # Get the output data from the reader
    grid = reader.GetOutput()  # Access the unstructured grid
    data_set = reader.GetOutputAsDataSet()  # Access the dataset representation
    points = grid.GetPoints()  # Extract the points (coordinates)
    cells = grid.GetCells()  # Extract the cell connectivity information
    point_data = data_set.GetPointData()  # Access point-wise data

    n_points = grid.GetNumberOfPoints()  # Number of points and cells
    n_cells = grid.GetNumberOfCells()
    num_arrays = point_data.GetNumberOfArrays()

    end = time.time()
    print("Reading files takes %.2e s" % (end - start))
    print(f"\tNumber of points: {n_points}")
    print(f"\tNumber of cells: {n_cells}")
    print("\tAvailable point data fields:")
    for i in range(num_arrays):
        # Field names in point data
        name = point_data.GetArrayName(i)
        print(f"\t  - {name}")
    start = end

    # Convert the point coordinates to numpy and keep only x and y (2D model)
    vtk_points = points.GetData()
    points_np = vtk_to_numpy(vtk_points)  # Shape: (n_points, 3)
    points_2d = points_np[:, :2]

    # Resolution for each point of the grid (hamageolib.utils.vtk_utilities)
    resolution_np = calculate_resolution(grid)

    end = time.time()
    print("Calculating resolution takes %.2e s" % (end - start))
    start = end

    # Triangulate the points once and share the triangulation between all
    # interpolators. Building each LinearNDInterpolator from raw points would
    # recompute the same Delaunay triangulation for every field.
    triangulation = Delaunay(points_2d)

    # Initialize dictionary for interpolators: one per point-data array
    interpolators = {}
    for i in range(num_arrays):
        array_name = point_data.GetArrayName(i)
        vtk_array = point_data.GetArray(i)

        if vtk_array is None:
            print(f"Warning: Array {array_name} is None, skipping.")
            continue

        # Convert VTK array to NumPy and add its interpolator (NaN outside the hull)
        np_array = vtk_to_numpy(vtk_array)
        interpolators[array_name] = LinearNDInterpolator(triangulation, np_array, fill_value=np.nan)

    # Interpolator for the resolution computed above
    interpolators["resolution"] = LinearNDInterpolator(triangulation, resolution_np, fill_value=np.nan)

    end = time.time()
    print("Construct linear ND interpolator takes %.2e s" % (end - start))
    start = end

In [None]:
if is_run_aspect_tests_1ky:

    start = time.time()

    # Spacing of the regular grid (in meters)
    interval = 1000.0

    # Determine the bounding box of the 2D points
    x_min, y_min = np.min(points_2d, axis=0)
    x_max, y_max = np.max(points_2d, axis=0)

    # Define a regular grid within the bounding box.
    # The x and y spacings differ slightly on purpose so the two grid
    # dimensions are unequal: mixing up the x and y axes then produces a
    # shape mismatch instead of a silently wrong result.
    xs = np.arange(x_min, x_max, interval*0.99)
    ys = np.arange(y_min, y_max, interval*1.01)
    x_grid, y_grid = np.meshgrid(xs, ys, indexing="ij")  # shape: (len(xs), len(ys))

    # Flatten the grid into (n, 2) query points for interpolation
    grid_points_2d = np.vstack([x_grid.ravel(), y_grid.ravel()]).T

    def _interpolate_to_grid(field_name):
        """Evaluate interpolators[field_name] at the regular-grid points and reshape to x_grid.shape."""
        return interpolators[field_name](grid_points_2d).reshape(x_grid.shape)

    # Temperature, pressure, mesh resolution, density and viscosity
    T_grid = _interpolate_to_grid("T")
    P_grid = _interpolate_to_grid("p")
    resolutions_grid = _interpolate_to_grid("resolution")
    density_grid = _interpolate_to_grid("density")
    viscosity_grid = _interpolate_to_grid("viscosity")

    # Grain-size fields
    if with_grain_size_evolution:
        grain_density_grid = _interpolate_to_grid("meta_x0")
        grain_size_grid = _interpolate_to_grid("meta_grain_size")

    # Latent heat (disabled)
    # lheat_grid = _interpolate_to_grid("latent_heat")

    # Metastable composition and its reaction rate
    if with_metastable:
        metastable_grid = _interpolate_to_grid("metastable")
        metarate_grid = _interpolate_to_grid("meta_rate")

    end = time.time()
    print("Interpolating to regular grid takes %.2e s" % (end - start))
    print("\tgrid shape: (x axis, y axis): ", x_grid.shape)
    start = end

In [None]:
if is_run_aspect_tests_1ky:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    from matplotlib import gridspec
    from cmcrameri import cm as ccm

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Plot settings.
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5  # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0  # extra scaling multiplier for lines
    x_lim = (0.0, 800)  # km
    x_tick_interval = 100  # tick interval along x (km)
    y_lim = (0.0, 1200)  # km
    y_tick_interval = 100  # tick interval along y (km)

    resolution_lim = (0.0, 100e3)  # resolution
    resolution_level = 50  # number of levels in contourf plot
    resolution_tick_interval = 25e3  # colorbar tick interval

    T_lim = (0.0, 2000.0)  # T
    T_level = 50  # number of levels in contourf plot
    T_tick_interval = 250.0  # colorbar tick interval

    # Pressure range taken from the data. nanmin/nanmax because the
    # interpolators fill NaN outside the triangulation; a single NaN would
    # otherwise turn both limits, and therefore every contour level, into NaN.
    P_lim = (np.nanmin(P_grid), np.nanmax(P_grid))  # P
    P_level = 50  # number of levels in contourf plot
    P_tick_interval = 1e9  # colorbar tick interval

    density_lim = (3000.0, 4000.0)
    density_level = 50  # number of levels in contourf plot
    density_tick_interval = 100.0  # colorbar tick interval

    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    def _style_map_axis(ax):
        """Apply the shared map-view styling to ax: equal aspect, x/y limits and ticks (km), labels, spine width."""
        ax.set_aspect("equal", adjustable="box")  # Equal aspect ratio

        ax.set_xlim(x_lim)
        ax.set_ylim(y_lim)

        ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
        ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
        ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
        ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

        ax.set_xlabel("X")
        ax.set_ylabel("Y")

        for spine in ax.spines.values():
            # Adjust spine thickness for this plot
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # Create a figure with a 5x2 grid layout
    fig = plt.figure(figsize=(12, 25), tight_layout=True)
    gs = gridspec.GridSpec(5, 2)

    # Plot the mesh resolution
    ax = fig.add_subplot(gs[0, 0])
    levels = np.linspace(resolution_lim[0], resolution_lim[1], resolution_level)
    ticks = np.arange(resolution_lim[0], resolution_lim[1], resolution_tick_interval)
    color_map = ax.contourf(x_grid/1e3, y_grid/1e3, resolutions_grid, vmin=resolution_lim[0], vmax=resolution_lim[1], levels=levels, cmap="plasma_r")
    cbar = fig.colorbar(color_map, ax=ax, label="Resolution")
    cbar.set_ticks(ticks)
    _style_map_axis(ax)

    # Plot T
    ax = fig.add_subplot(gs[0, 1])
    levels = np.linspace(T_lim[0], T_lim[1], T_level)
    ticks = np.arange(T_lim[0], T_lim[1], T_tick_interval)
    color_map = ax.contourf(x_grid/1e3, y_grid/1e3, T_grid, vmin=T_lim[0], vmax=T_lim[1], levels=levels, cmap=ccm.lapaz)
    cbar = fig.colorbar(color_map, ax=ax, label="T")
    cbar.set_ticks(ticks)
    _style_map_axis(ax)

    # Plot P
    ax = fig.add_subplot(gs[1, 0])
    levels = np.linspace(P_lim[0], P_lim[1], P_level)
    ticks = np.arange(P_lim[0], P_lim[1], P_tick_interval)
    color_map = ax.contourf(x_grid/1e3, y_grid/1e3, P_grid, vmin=P_lim[0], vmax=P_lim[1], levels=levels, cmap=ccm.tokyo_r)
    cbar = fig.colorbar(color_map, ax=ax, label="P")
    cbar.set_ticks(ticks)
    _style_map_axis(ax)

    # Plot density
    ax = fig.add_subplot(gs[1, 1])
    levels = np.linspace(density_lim[0], density_lim[1], density_level)
    ticks = np.arange(density_lim[0], density_lim[1], density_tick_interval)
    color_map = ax.contourf(x_grid/1e3, y_grid/1e3, density_grid, vmin=density_lim[0], vmax=density_lim[1], levels=levels, cmap=ccm.batlow)
    cbar = fig.colorbar(color_map, ax=ax, label="density")
    cbar.set_ticks(ticks)
    _style_map_axis(ax)

    # Latent-heat panel (gs[2, 0]) is disabled; to re-enable, also interpolate
    # "latent_heat" into lheat_grid in the regular-grid cell:
    # ax = fig.add_subplot(gs[2, 0])
    # heating_lim = (-2e-3, 2e-3)
    # levels = np.linspace(heating_lim[0], heating_lim[1], 50)
    # color_map = ax.contourf(x_grid/1e3, y_grid/1e3, lheat_grid, levels=levels, cmap=ccm.glasgow)
    # cbar = fig.colorbar(color_map, ax=ax, label="latent heat")
    # cbar.set_ticks(np.arange(heating_lim[0], heating_lim[1], 5e-4))
    # _style_map_axis(ax)

    # Plot meta rate. metarate_grid is only interpolated when with_metastable
    # is set, so the panel is guarded to avoid a NameError otherwise.
    if with_metastable:
        ax = fig.add_subplot(gs[2, 1])
        metarate_lim = (0, 1e-11)  # not applied: automatic contour levels are used
        color_map = ax.contourf(x_grid/1e3, y_grid/1e3, metarate_grid, cmap=ccm.buda)
        cbar = fig.colorbar(color_map, ax=ax, label="meta rate")
        _style_map_axis(ax)

    # Plot grain size: first the initial grain size, then the interpolated
    # grain size, then a synthetic grain size after growth, then viscosity.
    if with_grain_size_evolution:
        # The initial grain size is derived from metastable_grid, which only
        # exists when with_metastable is set.
        assert with_metastable, "Grain-size plots require with_metastable = True"

        from hamageolib.core.GrainSize import GrainGrowthModel, GrainGrowthParams
        year = 365 * 24 * 3600.0

        # Sphere-equivalent diameter d = (6 V / pi)^(1/3) with V = metastable / grain density
        # (assumes meta_x0 is a grain number density — TODO confirm)
        initial_grain_size_grid = (6.0 * metastable_grid / grain_density_grid / np.pi)**(1.0/3.0)
        grain_size_log_range = [-8.0, -1.0]
        levels = np.linspace(grain_size_log_range[0], grain_size_log_range[1], 50)
        ticks = np.arange(grain_size_log_range[0], grain_size_log_range[1]+0.1, 1.0)

        ax = fig.add_subplot(gs[3, 0])
        color_map = ax.contourf(x_grid/1e3, y_grid/1e3, np.log10(initial_grain_size_grid), cmap="viridis", levels=levels, extend="both")
        cbar = fig.colorbar(color_map, ax=ax, label="log10(initial grain_size)")
        cbar.set_ticks(ticks)
        _style_map_axis(ax)

        ax = fig.add_subplot(gs[3, 1])
        color_map = ax.contourf(x_grid/1e3, y_grid/1e3, np.log10(grain_size_grid), cmap="viridis", levels=levels, extend="both")
        cbar = fig.colorbar(color_map, ax=ax, label="log10(grain_size)")
        cbar.set_ticks(ticks)
        _style_map_axis(ax)

        # todo_gz
        # Parameters for Wd grain growth
        params = GrainGrowthParams(
            grain_growth_rate_constant=3.02e-4,
            m=3,
            grain_growth_activation_energy=6.62e5,
            grain_growth_activation_volume=0.0,
        )

        gModel = GrainGrowthModel(params=params)

        # Grow the initial grain size for 1 Myr at the local P and T; points
        # without a positive initial grain size stay NaN.
        time_scale = 1e6 * year
        synthetic_grain_size_grid = np.full(initial_grain_size_grid.shape, np.nan)
        mask = (initial_grain_size_grid > 0)
        synthetic_grain_size_grid[mask] = gModel.grain_size_at_time(initial_grain_size_grid[mask], time_scale, P_grid[mask], T_grid[mask])

        ax = fig.add_subplot(gs[4, 0])
        color_map = ax.contourf(x_grid/1e3, y_grid/1e3, np.log10(synthetic_grain_size_grid), cmap="viridis", levels=levels, extend="both")
        cbar = fig.colorbar(color_map, ax=ax, label="log10(grain_size)")
        cbar.set_ticks(ticks)
        _style_map_axis(ax)

        # plot viscosity
        ax = fig.add_subplot(gs[4, 1])
        levels = np.linspace(10.0, 30.0, 50)
        ticks = np.arange(10.0, 30.0+0.1, 1.0)
        color_map = ax.contourf(x_grid/1e3, y_grid/1e3, np.log10(viscosity_grid), cmap=ccm.tokyo_r, levels=levels, extend="both")
        cbar = fig.colorbar(color_map, ax=ax, label="log10(viscosity)")
        cbar.set_ticks(ticks)
        _style_map_axis(ax)

    # save figure as png and pdf
    if reaction_metastable_trivial:
        ofile_base = os.path.join(case_dir, "metastable_diagram_trivial_nrep_%d_vstep_%05d_raw" % (n_repetition, vtu_step))
    else:
        ofile_base = os.path.join(case_dir, "metastable_diagram_nrep_%d_vstep_%05d_raw" % (n_repetition, vtu_step))
    for suffix in (".png", ".pdf"):
        ofile = ofile_base + suffix
        fig.savefig(ofile)
        print("Saved figure %s" % ofile)

    # Reset rcParams to defaults
    rcdefaults()


Plot the diagrams at the given step

In [None]:
if is_run_aspect_tests_1ky:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    from matplotlib import gridspec
    from cmcrameri import cm as ccm

    # Colors of the active matplotlib property cycle
    default_colors = [entry['color'] for entry in plt.rcParams['axes.prop_cycle']]

    # Plot settings.
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # overall scale of the plot
    font_scaling_multiplier = 1.5  # additional font scaling
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0  # additional line-width scaling

    T_lim = (400.0, 1800.0)  # temperature axis, K
    T_level = 50
    T_tick_interval = 200.0  # major tick spacing along T

    T_lim1 = (0.0, 800.0)  # temperature in C, smaller range
    T_tick_interval1 = 200.0

    P_lim = (10.0, 30.0)  # pressure axis, GPa
    P_level = 50
    P_tick_interval = 5.0  # major tick spacing along P

    density_lim = (3000.0, 4000.0)
    density_level = 50
    density_tick_interval = 100.0  # colorbar tick spacing

    metastable_lim = (0.0, 1.0)  # metastable content
    metastable_level = 100
    metastable_interval = 0.2

    n_minor_ticks = 4  # minor ticks between consecutive major ticks

    plot_helper.scale_matplotlib_params(
        scaling_factor,
        font_scaling_multiplier=font_scaling_multiplier,
        legend_font_scaling_multiplier=legend_font_scaling_multiplier,
        line_width_scaling_multiplier=line_width_scaling_multiplier,
    )

    # Fonts embedded as TrueType so Illustrator-like tools can edit the output
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # One row, two panels: metastable content (left, optional) and density (right)
    fig = plt.figure(figsize=(10, 4.25), tight_layout=True)
    gs = gridspec.GridSpec(1, 2)

    # Axes that receive the shared T-P styling at the end
    tp_axes = []

    if with_metastable:
        ax = fig.add_subplot(gs[0, 0])

        levels = np.linspace(*metastable_lim, metastable_level)
        ticks = np.arange(*metastable_lim, metastable_interval)

        color_map = ax.contourf(T_grid, P_grid/1e9, metastable_grid, levels=levels,
                                vmin=metastable_lim[0], vmax=metastable_lim[1], cmap="viridis")

        # Contour lines of metastable content 0.5 (gray) and 0.99 (black)
        contour_099 = ax.contour(
            T_grid, P_grid / 1e9, metastable_grid,
            levels=[0.5, 0.99],
            colors=["tab:gray", 'k'],
            linewidths=1.5,
            linestyles='-'
        )

        cbar = fig.colorbar(color_map, ax=ax, label="Metastable")
        cbar.set_ticks(ticks)
        tp_axes.append(ax)

    ax = fig.add_subplot(gs[0, 1])

    levels = np.linspace(*density_lim, density_level)
    ticks = np.arange(*density_lim, density_tick_interval)

    color_map = ax.contourf(T_grid, P_grid/1e9, density_grid, levels=levels,
                            vmin=density_lim[0], vmax=density_lim[1], cmap=ccm.batlow)

    cbar = fig.colorbar(color_map, ax=ax, label="Density")
    cbar.set_ticks(ticks)
    tp_axes.append(ax)

    # Shared styling: T-P limits and ticks, grid, pressure increasing downward
    for panel in tp_axes:
        panel.set_xlim(T_lim)
        panel.set_ylim(P_lim)

        panel.xaxis.set_major_locator(MultipleLocator(T_tick_interval))
        panel.xaxis.set_minor_locator(MultipleLocator(T_tick_interval/(n_minor_ticks+1)))
        panel.yaxis.set_major_locator(MultipleLocator(P_tick_interval))
        panel.yaxis.set_minor_locator(MultipleLocator(P_tick_interval/(n_minor_ticks+1)))

        panel.grid()
        panel.invert_yaxis()

        panel.set_xlabel("T (K)")
        panel.set_ylabel("P (GPa)")

        for spine in panel.spines.values():
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    ofile_name = ("metastable_diagram_trivial_nrep_%d_vstep_%05d.pdf" if reaction_metastable_trivial
                  else "metastable_diagram_nrep_%d_vstep_%05d.pdf")
    ofile = os.path.join(case_dir, ofile_name % (n_repetition, vtu_step))

    fig.savefig(ofile)
    print("saved figure %s" % ofile)

    # Restore matplotlib defaults for subsequent cells
    rcdefaults()

### Plot summary of run time

In [None]:
# case_summary.csv is deleted, so skip for now
if is_run_aspect_tests_1ky and False:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    import pandas as pd

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Plot settings.
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 3.0  # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0  # extra scaling multiplier for lines
    x_lim = (0.0, 250.0)
    x_tick_interval = 50.0  # tick interval along x
    y_lim = (0.0, 300.0)
    y_tick_interval = 50.0  # tick interval along y
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    summary_file = os.path.join(case_dir, "summary.csv")

    # isfile, not join: a joined path string is always truthy, so a
    # check on os.path.join could never fail.
    assert os.path.isfile(summary_file), "Missing summary file: %s" % summary_file

    df = pd.read_csv(summary_file)

    # Create the plot
    fig, ax = plt.subplots(figsize=(8, 6))

    # One color per 'description': solid = total runtime, dashed = particle
    # property-update time. Distinct labels keep the legend unambiguous, and
    # the modulo avoids IndexError when there are more groups than colors.
    # ('repitition' is the column name as spelled in the CSV.)
    for i, (key, group) in enumerate(df.groupby('description')):
        color = default_colors[i % len(default_colors)]
        ax.plot(group['repitition'], group['runtime'], marker='o', linestyle="-",
                label="%s (runtime)" % key, color=color)
        ax.plot(group['repitition'], group['particle_update_properties_time'], marker='o', linestyle="--",
                label="%s (particle update)" % key, color=color)

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)

    ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
    ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
    ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
    ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

    # Axis labels and formatting
    ax.set_xlabel("Repetition")
    ax.set_ylabel("Runtime (s)")
    ax.legend()
    ax.grid(True)

    # Adjust spine thickness for this plot (before saving/showing so both see it)
    for spine in ax.spines.values():
        spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    plt.tight_layout()

    ofile = os.path.join(case_dir, "runtime_summary.pdf")
    fig.savefig(ofile)

    print("Saved figure: %s" % ofile)

    # Display last: with interactive backends plt.show() blocks, and with the
    # inline backend the figure is closed after being displayed.
    plt.show()

    # Reset rcParams to defaults
    rcdefaults()

## Advection test

List of options:

- test_backward_advection:
  advection runs backward (bottom-up). This checks that the metastable
  value is reset below equilibrium.

### Set up

In [None]:
is_run_aspect_tests_advection = False
is_run_aspect_tests_advection_solving = True

if is_run_aspect_tests_advection:

    # Model switches
    with_metastable = True
    reaction_metastable_trivial = False
    with_latent_heat = True
    test_backward_advection = True  # advect bottom-up instead of top-down

    adiabatic_surface_temperature = 1073.15  # K
    advection_rate = 0.10  # m/yr (assumed; the times below are in years)

    # Time stepping (yr)
    end_time = 1e4
    maximum_time_step = 1e3

    # Box extents and number of coarse-mesh repetitions along x
    x_extent = 2e3
    y_extent = 1000e3
    n_repetition = 4  # original - no repetition (x); documentation - 4

    # ASPECT build and the parameter template matching the advection direction
    aspect_dir = "/home/lochy/Softwares/aspect"
    aspect_executable = os.path.join(aspect_dir, "build_master_TwoD_rebase/aspect")
    prm_template_path = os.path.join(
        aspect_dir,
        "tests",
        "TwoDSubduction_metastable_backward_advection.prm" if test_backward_advection
        else "TwoDSubduction_metastable_advection.prm",
    )

    assert os.path.isfile(aspect_executable)
    assert os.path.isfile(prm_template_path)

    # Separate root directory in which the case is run
    case_root_dir = "/mnt/lochz/ASPECT_DATA/TwoDSubduction/test_cases"

In [None]:
if is_run_aspect_tests_advection:

    from hamageolib.utils.dealii_param_parser import parse_parameters_to_dict, save_parameters_from_dict
    from hamageolib.utils.world_builder_file_parser import find_feature_by_name, update_or_add_feature

    # New directory to run the case, placed under the case_root_dir chosen in
    # the set-up cell instead of repeating that path here.
    if test_backward_advection:
        case_dir = os.path.join(case_root_dir, "TwoDSubduction_metastable_backward_advection")
    else:
        case_dir = os.path.join(case_root_dir, "TwoDSubduction_metastable_advection")
    # makedirs(exist_ok=True) also creates a missing case_root_dir
    os.makedirs(case_dir, exist_ok=True)

    # Modify the template
    # Also read important parameters like the size of the model
    with open(prm_template_path, 'r') as file:
        params_dict = parse_parameters_to_dict(file)

    # Output directory name encodes the main options of the run
    if with_metastable:
        if reaction_metastable_trivial:
            output_dirname = "output_reaction_trivial_nrep_%d" % n_repetition
        else:
            output_dirname = "output_nrep_%d" % n_repetition
    else:
        output_dirname = "output_trivial_nrep_%d" % n_repetition

    output_dirname += "_end_%.2e_maxstep_%.2e_lt_%d" % (end_time, maximum_time_step, with_latent_heat)

    params_dict["Output directory"] = os.path.join(case_dir, output_dirname)
    params_dict["End time"] = "%.2e" % end_time
    params_dict["Maximum time step"] = "%.2e" % maximum_time_step

    params_dict["Additional shared libraries"] = "$ASPECT_SOURCE_DIR/build_master_TwoD_rebase/subduction_temperature2d/libsubduction_temperature2d.so, $ASPECT_SOURCE_DIR/build_master_TwoD_rebase/prescribe_field/libprescribed_temperature.so, $ASPECT_SOURCE_DIR/build_master_TwoD_rebase/visco_plastic_TwoD/libvisco_plastic_TwoD.so"

    params_dict["Adiabatic surface temperature"] = str(adiabatic_surface_temperature)

    # Box geometry; Y repetitions scale with the aspect ratio (rounded up)
    params_dict["Geometry model"]["Box"]["X extent"] = str(x_extent)
    params_dict["Geometry model"]["Box"]["Y extent"] = str(y_extent)
    params_dict["Geometry model"]["Box"]["X repetitions"] = str(n_repetition)
    params_dict["Geometry model"]["Box"]["Y repetitions"] = str(int(np.ceil(n_repetition*y_extent/x_extent)))

    # Prescribed velocity "vx; vy": positive vy for the backward (bottom-up) test
    if test_backward_advection:
        params_dict["Prescribed Stokes solution"]["Velocity function"]["Function expression"] = "0; %.2e" % (advection_rate)
    else:
        params_dict["Prescribed Stokes solution"]["Velocity function"]["Function expression"] = "0; %.2e" % (-advection_rate)

    # Without the metastable field, strip the compositional/particle machinery
    if not with_metastable:
        params_dict.pop("Compositional fields")

        params_dict["Initial temperature model"].pop("Adiabatic")

        params_dict.pop("Initial composition model")

        params_dict["Material model"]["Visco Plastic TwoD"]["Reaction metastable"] = "false"
        params_dict["Material model"]["Visco Plastic TwoD"]["Densities"] = "background: 3300.0|3394.4|3442.1|3453.2|3617.6|3691.5|3774.7|3929.1"

        params_dict["Material model"]["Visco Plastic TwoD"].pop("Metastable transition")

        params_dict.pop("Particles")

        params_dict["Postprocess"].pop("Particles")

    if reaction_metastable_trivial:
        params_dict["Material model"]["Visco Plastic TwoD"]["Reaction metastable trivial"] = "true"

    if with_latent_heat:
        params_dict["Heating model"]["List of model names"] = "adiabatic heating, latent heat"
    else:
        params_dict["Heating model"]["List of model names"] = "adiabatic heating"

    # Write field (and, with metastable, particle) output once per maximum time step
    params_dict["Postprocess"]["Visualization"]["Output format"] = "vtu"
    params_dict["Postprocess"]["Visualization"]["Time between graphical output"] = "%.1e" % maximum_time_step
    if with_metastable:
        params_dict["Postprocess"]["Particles"]["Data output format"] = "vtu"
        params_dict["Postprocess"]["Particles"]["Time between data output"] = "%.1e" % maximum_time_step

    # Write to a prm file in the new case directory
    prm_path = os.path.join(case_dir, "case.prm")

    with open(prm_path, 'w') as output_file:
        save_parameters_from_dict(output_file, params_dict)

    assert os.path.isfile(prm_path), "Failed to write %s" % prm_path

    print("Created case in %s" % (case_dir))

### Run

Use "subprocess.run" to run the case.

Capture the standard output and error streams

Check that
  * the expected line indicating the total wallclock time appears in the output, and
  * there is no stderr output.

In [None]:
if is_run_aspect_tests_advection and is_run_aspect_tests_advection_solving:

    # Run the ASPECT executable with the parameter file.
    # 'capture_output=True' collects both stdout and stderr for the checks below;
    # 'text=True' decodes them to str.
    completed_process = subprocess.run([aspect_executable, prm_path], capture_output=True, text=True)

    # Capture the standard output and error streams
    stdout = completed_process.stdout
    stderr = completed_process.stderr

    # Uncomment the following lines for debugging purposes to inspect the output
    # print(stdout)  # Debugging: Prints the standard output
    # print(stderr)  # Debugging: Prints the standard error

    # Last lines of stdout, quoted in the failure messages below
    stdout_tail = "\n".join(stdout.splitlines()[-20:])

    # A non-zero exit status is a failed run regardless of what was printed.
    assert completed_process.returncode == 0, \
        "ASPECT exited with status %d.\nstderr:\n%s\nstdout (tail):\n%s" % (completed_process.returncode, stderr, stdout_tail)

    # Check that the summary line indicating wallclock time appears in the output, e.g.
    # -- Total wallclock time elapsed including restarts: 1s
    # Search for it rather than reading a fixed line index: a hard-coded index
    # breaks (or raises IndexError on short output) whenever the number of lines
    # printed after the summary changes.
    assert "Total wallclock" in stdout, \
        "Expected 'Total wallclock' line not found in stdout.\nstdout (tail):\n%s" % stdout_tail

    # Ensure that the error stream is empty, indicating no issues during the run
    assert stderr == "", "ASPECT wrote to stderr:\n%s" % stderr

Save stdout to a separate file.

In [None]:
if is_run_aspect_tests_advection and is_run_aspect_tests_advection_solving:

    # Keep the captured ASPECT log as "stdout.txt" inside the case directory
    std_file = os.path.join(case_dir, "stdout.txt")
    with open(std_file, "w") as log_file:
        log_file.write(stdout)

    print("Saved stdout outputs: %s" % std_file)

### Post-process

In [None]:
if is_run_aspect_tests_advection:

    vtu_step = 0  # index of the visualization output (solution-XXXXX.pvtu) to read

    import vtk
    from vtk.util.numpy_support import vtk_to_numpy
    from hamageolib.utils.vtk_utilities import calculate_resolution
    import time
    from scipy.interpolate import LinearNDInterpolator

    pvtu_file = os.path.join(case_dir, output_dirname, "solution", "solution-%05d.pvtu" % vtu_step)
    assert(os.path.isfile(pvtu_file))

    # Read the pvtu file
    start = time.time()

    reader = vtk.vtkXMLPUnstructuredGridReader()
    reader.SetFileName(pvtu_file)
    reader.Update()

    end = time.time()
    print("Initiating reader takes %.2e s" % (end - start))
    start = end

    # Get the output data from the reader
    grid = reader.GetOutput()  # Access the unstructured grid
    data_set = reader.GetOutputAsDataSet()  # Access the dataset representation
    points = grid.GetPoints()  # Extract the points (coordinates)
    cells = grid.GetCells()  # Extract the cell connectivity information
    point_data = data_set.GetPointData()  # Access point-wise data

    n_points = grid.GetNumberOfPoints() # Number of points and cells
    n_cells = grid.GetNumberOfCells()

    end = time.time()
    print("Reading files takes %.2e s" % (end - start))
    print(f"\tNumber of points: {n_points}")
    print(f"\tNumber of cells: {n_cells}")
    print("\tAvailable point data fields:")
    for i in range(point_data.GetNumberOfArrays()):
        # Field names in point data
        name = point_data.GetArrayName(i)
        print(f"\t  - {name}")
    start = end

    # Convert data to numpy arrays
    # Coordinates: keep only x and y, since the model is 2D
    vtk_points = grid.GetPoints().GetData()
    points_np = vtk_to_numpy(vtk_points)  # Shape: (n_points, 3)
    points_2d = points_np[:, :2]  # Use only the first two columns for 2D coordinates

    # Point data fields; GetArray returns None when a field is not in the file
    vtk_T = point_data.GetArray("T")
    vtk_p = point_data.GetArray("p")
    metastable = point_data.GetArray("metastable")
    metarate = point_data.GetArray("meta_rate")
    vtk_velocity = point_data.GetArray("velocity")
    vtk_density = point_data.GetArray("density")
    vtk_aheat = point_data.GetArray("adiabatic_heating")
    vtk_lheat = point_data.GetArray("latent_heat")

    # Check every field that is converted below, so a missing output variable
    # fails here with its name instead of later with an opaque error from
    # vtk_to_numpy. Latent heat and the metastable fields are only required
    # when the corresponding option is enabled.
    required_arrays = {
        "T": vtk_T,
        "p": vtk_p,
        "velocity": vtk_velocity,
        "density": vtk_density,
        "adiabatic_heating": vtk_aheat,
    }
    if with_latent_heat:
        required_arrays["latent_heat"] = vtk_lheat
    if with_metastable:
        required_arrays["metastable"] = metastable
        required_arrays["meta_rate"] = metarate
    missing_fields = [field_name for field_name, array in required_arrays.items() if array is None]
    assert not missing_fields, "Point data field(s) %s missing from %s" % (missing_fields, pvtu_file)

    T_np = vtk_to_numpy(vtk_T)  # Shape: (n_points,)
    p_np = vtk_to_numpy(vtk_p)  # Shape: (n_points,)
    v_np = vtk_to_numpy(vtk_velocity)
    aheat_np = vtk_to_numpy(vtk_aheat)
    if with_latent_heat:
        lheat_np = vtk_to_numpy(vtk_lheat)
    density_np = vtk_to_numpy(vtk_density)

    if with_metastable:
        metastable_np = vtk_to_numpy(metastable)
        metarate_np = vtk_to_numpy(metarate)

    end = time.time()
    print("Converting data takes %.2e s" % (end - start))
    start = end

    # Calculate resolution for each cell or point in the grid
    # (calculate_resolution is imported from hamageolib.utils.vtk_utilities above)
    resolution_np = calculate_resolution(grid)

    end = time.time()
    print("Calculating resolution takes %.2e s" % (end - start))
    start = end


    # Create piecewise-linear interpolators over the 2D point cloud for each field
    interpolator = LinearNDInterpolator(points_2d, T_np)  # Interpolator for temperature
    interpolator_P = LinearNDInterpolator(points_2d, p_np)  # Interpolator for pressure
    interpolator_r = LinearNDInterpolator(points_2d, resolution_np)  # Interpolator for resolution
    interpolator_v = LinearNDInterpolator(points_2d, v_np)  # Interpolator for velocity
    interpolator_density = LinearNDInterpolator(points_2d, density_np)  # Interpolator for density
    interpolator_aheat = LinearNDInterpolator(points_2d, aheat_np)  # Interpolator for adiabatic heating
    if with_latent_heat:
        interpolator_lheat = LinearNDInterpolator(points_2d, lheat_np)  # Interpolator for latent heat
    if with_metastable:
        interpolator_meta = LinearNDInterpolator(points_2d, metastable_np)  # Interpolator for metastable
        interpolator_metarate = LinearNDInterpolator(points_2d, metarate_np)  # Interpolator for meta rate

    end = time.time()
    print("Construct linear ND interpolator takes %.2e s" % (end - start))
    start = end

In [None]:
if is_run_aspect_tests_advection:

    start = time.time()

    # Number of sample points along x. The count along y is scaled by the
    # domain aspect ratio so the spacing in both directions is roughly similar.
    n_x = 5
    n_y = int(np.ceil(n_x*y_extent/x_extent))

    # Bounding box of the model domain
    x_min = y_min = 0.0
    x_max, y_max = x_extent, y_extent

    # Regular grid over the bounding box. n_x and n_y generally differ, so
    # swapping the x and y axes somewhere shows up as a shape mismatch
    # instead of passing silently.
    xs = np.linspace(x_min, x_max, n_x)
    ys = np.linspace(y_min, y_max, n_y)
    x_grid, y_grid = np.meshgrid(xs, ys, indexing="ij")  # both of shape (n_x, n_y)

    # Query points as an (n_x*n_y, 2) array of (x, y) pairs
    grid_points_2d = np.column_stack((x_grid.ravel(), y_grid.ravel()))

    # Evaluate each linear interpolator at the query points and fold the
    # flat result back onto the (n_x, n_y) grid.
    T_grid = interpolator(grid_points_2d).reshape(x_grid.shape)  # temperature
    P_grid = interpolator_P(grid_points_2d).reshape(x_grid.shape)  # pressure
    resolutions_grid = interpolator_r(grid_points_2d).reshape(x_grid.shape)  # mesh resolution

    # Velocity is vector-valued: take the first two components
    v_interp_flat = interpolator_v(grid_points_2d)
    vx_grid, vy_grid = (v_interp_flat[:, k].reshape(x_grid.shape) for k in (0, 1))

    density_grid = interpolator_density(grid_points_2d).reshape(x_grid.shape)  # density
    aheat_grid = interpolator_aheat(grid_points_2d).reshape(x_grid.shape)  # adiabatic heating

    if with_latent_heat:
        lheat_grid = interpolator_lheat(grid_points_2d).reshape(x_grid.shape)  # latent heat

    if with_metastable:
        metastable_grid = interpolator_meta(grid_points_2d).reshape(x_grid.shape)  # metastable composition
        metarate_grid = interpolator_metarate(grid_points_2d).reshape(x_grid.shape)  # meta rate

    end = time.time()
    print("Interpolating to regular grid takes %.2e s" % (end - start))
    print("\tgrid shape: (x axis, y axis): ", x_grid.shape)
    start = end

In [None]:
if is_run_aspect_tests_advection:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    from matplotlib import gridspec
    from cmcrameri import cm as ccm

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Plot settings. Rules of thumb:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (x_min/1e3, x_max/1e3) # km
    x_tick_interval = 0.5   # tick interval along x
    y_lim = (y_min/1e3, y_max/1e3) # km
    y_tick_interval = 100.0  # tick interval along y

    # Per field: colorbar limits, number of contour levels, colorbar tick interval
    resolution_lim = (0.0, 1e3)
    resolution_level = 50
    resolution_tick_interval = 0.25e3

    T_lim = (0.0, 2000.0)
    T_level = 50
    T_tick_interval = 250.0

    P_lim = (0.0, 40e9)
    P_level = 50
    P_tick_interval = 5e9

    v_lim = (-0.15, 0.15)
    v_level = 50
    v_tick_interval = 0.05

    density_lim = (3000.0, 4000.0)
    density_level = 50
    density_tick_interval = 100.0

    metastable_lim = (0.0, 1.0)
    metastable_level = 50
    metastable_tick_interval = 0.25

    # The meta-rate panel scales to the data maximum; metarate_lim[1] is only
    # used as a fallback when that maximum is not usable.
    metarate_lim = (0.0, 1.0)
    metarate_level = 50
    metarate_tick_interval = 0.25

    heating_lim = (-1e-4, 1e-4)
    heating_level = 50
    heating_tick_interval = 2.5e-5

    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    def _levels_and_ticks(lim, n_level, tick_interval):
        """Return (levels, ticks) for a panel.

        levels: n_level evenly spaced contour levels from lim[0] to lim[1];
        ticks: colorbar ticks from lim[0] in steps of tick_interval (np.arange,
        so lim[1] itself is excluded).
        """
        return np.linspace(lim[0], lim[1], n_level), np.arange(lim[0], lim[1], tick_interval)

    def _plot_field_panel(ax, field, levels, ticks, cmap, label):
        """Draw `field` (on x_grid/y_grid) as filled contours in `ax` with a colorbar.

        The color range is levels[0]..levels[-1]. Axes are in km, with depth
        (y_extent - y) increasing downward. Shared tick locators, axis labels and
        spine widths are applied so every panel looks the same.
        """
        color_map = ax.contourf(x_grid/1e3, (y_extent-y_grid)/1e3, field,
                                vmin=levels[0], vmax=levels[-1], levels=levels, cmap=cmap)
        cbar = fig.colorbar(color_map, ax=ax, label=label)  # Add colorbar
        cbar.set_ticks(ticks)

        ax.set_xlim(x_lim)
        ax.set_ylim(y_lim)

        ax.xaxis.set_major_locator(MultipleLocator(x_tick_interval))
        ax.xaxis.set_minor_locator(MultipleLocator(x_tick_interval/(n_minor_ticks+1)))
        ax.yaxis.set_major_locator(MultipleLocator(y_tick_interval))
        ax.yaxis.set_minor_locator(MultipleLocator(y_tick_interval/(n_minor_ticks+1)))

        ax.set_xlabel("X (km)")
        ax.set_ylabel("Depth (km)")

        ax.invert_yaxis()

        for spine in ax.spines.values():
            # Adjust spine thickness for this plot
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    # Create a figure with a 5x2 grid layout (one panel per field)
    fig = plt.figure(figsize=(12, 25), tight_layout=True)
    gs = gridspec.GridSpec(5, 2)

    # Mesh resolution
    ax = fig.add_subplot(gs[0, 0])
    levels, ticks = _levels_and_ticks(resolution_lim, resolution_level, resolution_tick_interval)
    _plot_field_panel(ax, resolutions_grid, levels, ticks, "plasma_r", "Resolution")

    # Temperature
    ax = fig.add_subplot(gs[0, 1])
    levels, ticks = _levels_and_ticks(T_lim, T_level, T_tick_interval)
    _plot_field_panel(ax, T_grid, levels, ticks, ccm.lapaz, "T")

    # Pressure
    ax = fig.add_subplot(gs[1, 0])
    levels, ticks = _levels_and_ticks(P_lim, P_level, P_tick_interval)
    _plot_field_panel(ax, P_grid, levels, ticks, ccm.tokyo_r, "P")

    # Horizontal velocity
    ax = fig.add_subplot(gs[1, 1])
    levels, ticks = _levels_and_ticks(v_lim, v_level, v_tick_interval)
    _plot_field_panel(ax, vx_grid, levels, ticks, ccm.hawaii, "vx")

    # Vertical velocity
    ax = fig.add_subplot(gs[2, 0])
    levels, ticks = _levels_and_ticks(v_lim, v_level, v_tick_interval)
    _plot_field_panel(ax, vy_grid, levels, ticks, ccm.hawaii, "vy")

    # Density
    ax = fig.add_subplot(gs[2, 1])
    levels, ticks = _levels_and_ticks(density_lim, density_level, density_tick_interval)
    _plot_field_panel(ax, density_grid, levels, ticks, ccm.batlow, "density")

    # Metastable composition
    if with_metastable:
        ax = fig.add_subplot(gs[3, 0])
        levels, ticks = _levels_and_ticks(metastable_lim, metastable_level, metastable_tick_interval)
        _plot_field_panel(ax, metastable_grid, levels, ticks, "viridis", "metastable")

    # Adiabatic heating
    ax = fig.add_subplot(gs[3, 1])
    levels, ticks = _levels_and_ticks(heating_lim, heating_level, heating_tick_interval)
    _plot_field_panel(ax, aheat_grid, levels, ticks, ccm.glasgow, "adiabatic heating")

    # Latent heat (same color scale as adiabatic heating)
    if with_latent_heat:
        ax = fig.add_subplot(gs[4, 0])
        levels, ticks = _levels_and_ticks(heating_lim, heating_level, heating_tick_interval)
        _plot_field_panel(ax, lheat_grid, levels, ticks, ccm.glasgow, "latent heat")

    # Meta rate, scaled to the data maximum
    if with_metastable:
        ax = fig.add_subplot(gs[4, 1])
        # nanmax: points outside the interpolation hull are NaN. Fall back to the
        # nominal upper limit when the field is all zero/NaN, because contour
        # levels must be strictly increasing.
        metarate_max = np.nanmax(metarate_grid)
        if not (np.isfinite(metarate_max) and metarate_max > 0.0):
            metarate_max = metarate_lim[1]
        levels = np.linspace(0.0, metarate_max, metarate_level)
        ticks = np.linspace(0.0, metarate_max, 10)
        _plot_field_panel(ax, metarate_grid, levels, ticks, ccm.buda, "meta rate")

    # Output file name encodes the case type and the main run parameters
    if with_metastable:
        ofile_prefix = "metastable_advection_trivial" if reaction_metastable_trivial else "metastable_advection"
    else:
        ofile_prefix = "trivial_advection"
    ofile_name = os.path.join(case_dir, ofile_prefix + "_adT_%.1f_adv_%.3f_end_%.2e_maxstep_%.2e_vstep_%d_lt_%d"
                              % (adiabatic_surface_temperature, advection_rate, end_time, maximum_time_step, vtu_step, with_latent_heat))

    fig.savefig(ofile_name + ".png")
    print("Saved figure %s" % (ofile_name + ".png"))
    fig.savefig(ofile_name + ".pdf")
    print("Saved figure %s" % (ofile_name + ".pdf"))

    # Reset rcParams to defaults
    rcdefaults()

# ASPECT Cases

## Viscosity profile

In [None]:
# todo_visc
test_viscosity_profile = False

if test_viscosity_profile:

    viscosity_jump_type = "1100i" # "660", "1100" or "1100i"

    from hamageolib.research.haoyuan_2d_subduction.legacy_tools import RHEOLOGY_PRM, RHEOLOGY_OPR, RefitRheology

    # constant variables
    rheology_name = "WarrenHansen23"
    mantle_coh = 300.0
    strain_rate = 1e-15

    # Lower-mantle settings for each supported viscosity jump type:
    # (jump_lower_mantle, Vdiff_lm, depth_lm_middle)
    lower_mantle_options = {
        "660": (60.0, 3e-6, -1.0),
        "1100": (0.5, 9e-6, 1100e3),
        "1100i": (5, 6e-6, 1100e3),
    }
    try:
        jump_lower_mantle, Vdiff_lm, depth_lm_middle = lower_mantle_options[viscosity_jump_type]
    except KeyError:
        raise NotImplementedError("Unknown viscosity_jump_type %r; expected one of %s"
                                  % (viscosity_jump_type, sorted(lower_mantle_options))) from None

    rheology_prm_dict = RHEOLOGY_PRM()
    Operator = RHEOLOGY_OPR()

    # import a depth average profile
    LEGACY_FILE_DIR = os.path.join(root_path, "hamageolib/research/haoyuan_2d_subduction/legacy_files")
    da_file = os.path.join(LEGACY_FILE_DIR, 'reference_ThD', "depth_average_1573.txt")
    Operator.ReadProfile(da_file)
    depths, pressures, temperatures = Operator.depths, Operator.pressures, Operator.temperatures

    # Temperature and pressure interpolated from the profile at 660 km and 1500 km depth
    T660, P660 = np.interp(660e3, depths, temperatures), np.interp(660e3, depths, pressures)
    T1500, P1500 = np.interp(1500e3, depths, temperatures), np.interp(1500e3, depths, pressures)

    # initial rheologic parameters
    diffusion_creep_ori = getattr(rheology_prm_dict, rheology_name + "_diff")
    dislocation_creep_ori = getattr(rheology_prm_dict, rheology_name + "_disl")
    rheology_dict = {'diffusion': diffusion_creep_ori, 'dislocation': dislocation_creep_ori}
    # prescribe the correction
    diff_correction = {'A': 1.0, 'p': 0.0, 'r': 0.0, 'n': 0.0, 'E': 0.0, 'V': -2.1e-6}
    disl_correction = {'A': 1.0, 'p': 0.0, 'r': 0.0, 'n': 0.0, 'E': 0.0, 'V': 3e-6}
    # prescribe the reference state
    ref_state = {}
    ref_state["Coh"] = mantle_coh # H / 10^6 Si
    ref_state["stress"] = 50.0 # MPa
    ref_state["P"] = 100.0e6 # Pa
    ref_state["T"] = 1250.0 + 273.15 # K
    ref_state["d"] = 15.0 # mu m
    # refit rheology
    rheology_dict_refit = RefitRheology(rheology_dict, diff_correction, disl_correction, ref_state)
    # derive mantle rheology
    rheology, viscosity_profile = Operator.MantleRheology(assign_rheology=True, diffusion_creep=rheology_dict_refit['diffusion'],\
                                                dislocation_creep=rheology_dict_refit['dislocation'], save_profile=0,\
                                                use_effective_strain_rate=True, save_json=1, Coh=mantle_coh,\
                                                jump_lower_mantle=jump_lower_mantle, Vdiff_lm=Vdiff_lm, depth_lm_middle=depth_lm_middle)

In [None]:
if test_viscosity_profile:

    from hamageolib.research.haoyuan_2d_subduction.legacy_tools import CreepRheologyInAspectViscoPlastic

    # Creep laws returned by MantleRheology: upper-mantle diffusion and
    # dislocation creep, and lower-mantle diffusion creep
    diff_um, disl_um, diff_lm = (rheology[key] for key in ("diffusion_creep", "dislocation_creep", "diffusion_lm"))

    # Evaluate each creep law at the reference strain rate and at the P-T
    # conditions interpolated at 660 km depth
    visc_660_diff_um = CreepRheologyInAspectViscoPlastic(diff_um, strain_rate, P660, T660)
    visc_660_disl_um = CreepRheologyInAspectViscoPlastic(disl_um, strain_rate, P660, T660)
    visc_660_lm = CreepRheologyInAspectViscoPlastic(diff_lm, strain_rate, P660, T660)

    # Composite upper-mantle viscosity as the harmonic sum
    # 1/eta = 1/eta_diff + 1/eta_disl
    visc_660_um = 1.0 / (1.0/visc_660_diff_um + 1.0/visc_660_disl_um)

    for label, value in (("visc_660_diff_um: ", visc_660_diff_um),
                         ("visc_660_disl_um: ", visc_660_disl_um),
                         ("visc_660_um: ", visc_660_um),
                         ("visc_660_lm: ", visc_660_lm)):
        print(label, value)

    print(rheology)

In [None]:
if test_viscosity_profile:

    from matplotlib import rcdefaults

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Plot settings. Rules of thumb:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 1.0
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines
    x_lim = (0.0, 10.0)
    x_tick_interval = 2.0   # tick interval along x
    y_lim = (0.0, 100.0)
    y_tick_interval = 20.0  # tick interval along y
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })


    ymax = 2890.0 # km

    ylim=[ymax, 0.0]  # reversed limits: depth increases downward
    masky = (depths/1e3 < ymax)  # profile samples shallower than ymax, used to pick x ranges

    # get diffusion and dislocation profile
    eta_diff = viscosity_profile["diffusion"]
    eta_disl = viscosity_profile["dislocation"]
    eta = viscosity_profile["composite"]
    eta13 = viscosity_profile["composite_13"]
    eta_disl13 = viscosity_profile["dislocation_13"]

    fig, axs = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True)

    # first: pressure
    color = 'tab:blue'
    axs[0].plot(pressures/1e9, depths/1e3, color=color, label='pressure')
    axs[0].set_ylabel('Depth [km]') 
    axs[0].set_xlabel('Pressure [GPa]', color=color)
    Pmax = np.ceil(np.max(pressures[masky]/1e9) / 10.0) *10.0
    axs[0].set_xlim([0.0, Pmax])
    axs[0].set_ylim(ylim)

    # ax2: temperature, sharing the depth axis of the first panel
    color = 'tab:red'
    ax2 = axs[0].twiny()
    ax2.set_ylim(ylim)
    ax2.plot(temperatures, depths/1e3, color=color, label='temperature')
    Tmax = np.ceil(np.max(temperatures[masky]) / 100.0) *100.0
    ax2.set_xlim([0.0, Tmax])
    ax2.set_xlabel('Temperature [K]', color=color) 

    # second: viscosity at the reference strain rate
    axs[1].semilogx(eta_diff, depths/1e3, 'c', label='diffusion creep')
    axs[1].semilogx(eta_disl, depths/1e3, 'g', label='dislocation creep(%.2e)' % strain_rate)
    axs[1].semilogx(eta, depths/1e3, 'r--', label='Composite')
    axs[1].set_xlim([1e19,1e24])
    axs[1].set_ylim(ylim)
    axs[1].grid()
    axs[1].set_ylabel('Depth [km]')
    axs[1].legend()

    # third: viscosity with dislocation creep evaluated at a strain rate of 1e-13
    axs[2].semilogx(eta_diff, depths/1e3, 'c', label='diffusion creep')
    axs[2].semilogx(eta_disl13, depths/1e3, 'g', label='dislocation creep(%.2e)' % 1e-13)
    axs[2].semilogx(eta13, depths/1e3, 'r--', label='Composite')
    axs[2].set_xlim([1e19,1e24])
    axs[2].set_ylim(ylim)
    axs[2].grid()
    axs[2].set_ylabel('Depth [km]')
    axs[2].legend()

    # Adjust spine thickness on every axis of this figure. This has to happen
    # before savefig, otherwise the saved file does not reflect it.
    for axis in list(axs) + [ax2]:
        for spine in axis.spines.values():
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

    fig_path=os.path.join(results_dir, "viscosity_profile_combined_%s.pdf" % viscosity_jump_type)
    fig.savefig(fig_path)
    print("Saved figure %s" % fig_path)

    # Reset rcParams to defaults
    rcdefaults()

## 2D Case

### Case Options

- option 0: case is adapted from F100sa80oa40Rwedge in the twod case
- option 1: case is adapted from C2d_SA80_OA40_l8896_h1000_s300 in the 3d-consistent 2d case

In [None]:
def mow_case_name(case_name_base, geometry, box_height, include_metastable,
                  global_refine_level, adaptive_refine_level, viscosity_jump_type="660",
                  shear_zone_thickness=15e3, shear_zone_visc=1e20):
    """
    Compose the directory name of a MOW case from its key options.

    The name starts with a geometry prefix ("C_" for a box, "Sp_" otherwise), the base
    label and the box height divided by 1e3, followed by the refinement levels. Tags for the
    viscosity profile and shear zone are appended only when they differ from the reference
    values ("660", 15e3, 1e20), so reference cases keep short names. The defaults of the
    last three parameters are those reference values, so callers that do not vary them can
    omit them.

    Args:
        case_name_base (str): base label of the case (e.g. "mow").
        geometry (str): model geometry; "box" gives the "C_" prefix, anything else "Sp_".
        box_height (float): domain height; written as box_height/1e3 with one decimal
            (presumably m -> km).
        include_metastable (int or bool): a truthy value appends "_M".
        global_refine_level (int): global refinement level, appended as "_gr%d".
        adaptive_refine_level (int): adaptive refinement level, appended as "_ar%d".
        viscosity_jump_type (str): viscosity-profile type; anything other than "660"
            appends "_jp<type>".
        shear_zone_thickness (float): shear-zone thickness; a value not close to 15e3
            appends "_szT<thickness/1e3 with 2 decimals>".
        shear_zone_visc (float): shear-zone viscosity (must be > 0); a value whose log10
            is not close to 20 appends "_szV<value in %.2e>".

    Returns:
        str: the case name, e.g. "C_mow_h2890.0_M_gr3_ar5_jp1100i_szT7.50_szV5.00e+19".
    """
    import math

    # geometry prefix, base label and box height
    prefix = "C" if geometry == "box" else "Sp"
    case_name = "%s_%s_h%.1f" % (prefix, case_name_base, box_height/1e3)

    if include_metastable:
        case_name = "%s_M" % case_name
    case_name = "%s_gr%d_ar%d" % (case_name, global_refine_level, adaptive_refine_level)

    # viscosity profile: the reference "660" profile is not tagged
    if viscosity_jump_type != "660":
        case_name = "%s_jp%s" % (case_name, viscosity_jump_type)

    # shear zone: only tag values that differ from the reference thickness / viscosity
    if not math.isclose(shear_zone_thickness, 15e3, rel_tol=1e-9):
        case_name = "%s_szT%.2f" % (case_name, shear_zone_thickness/1e3)
    # compared in log10 space so that e.g. 1.0000001e20 still counts as the reference
    if not math.isclose(np.log10(shear_zone_visc), 20.0, rel_tol=1e-9):
        case_name = "%s_szV%.2e" % (case_name, shear_zone_visc)

    return case_name

In [None]:
run_aspect_2d_test = False
run_aspect_2d_test_first_step = False

if run_aspect_2d_test:

    # paths
    aspect_dir = "/home/lochy/Softwares/aspect" # aspect directory (local, optional if not run locally)
    aspect_executable = os.path.join(aspect_dir, "build_master_TwoD_rebase/aspect") # build directory (local, optional if not run locally)
    slurm_base_path = os.path.join(root_path, "scripts/slurm/250816/job_hive_high.sh")

    # options
    case_option_id = 1 # 0 - same as the shear zone model, 1 - same as the 3-d model
    include_metastable = 0  # 1 - include metastable
    case_root_dir = "/mnt/lochy/ASPECT_DATA/MOW/no_mow_sz_jump"  # parent directory
    case_name_base = "mow" # case directory
    geometry = "box"; box_height = 2890e3 # geometry setup
    global_refine_level = 3
    adaptive_refine_level = 5

    # viscosity profile
    viscosity_jump_type = "1100i" # 660 or 1100 or 1100i

    # shear zone
    shear_zone_thickness = 7.5e3
    shear_zone_visc = 5e19
    case_name = mow_case_name(case_name_base, geometry, box_height, include_metastable,
                              global_refine_level, adaptive_refine_level, viscosity_jump_type,
                              shear_zone_thickness, shear_zone_visc)

    # mantle rheology for each viscosity-profile type
    if viscosity_jump_type == "660":
        mantle_rheology_dict = {
            "scheme": "HK03_WarrenHansen23",
            "flow law": "composite",
            "adjust detail": 1,
            "jump lower mantle": 60.0,
            "Coh": 300.0,
            "use 3d da file": 1
        }
    elif viscosity_jump_type == "1100":
        mantle_rheology_dict = {
            "scheme": "HK03_WarrenHansen23",
            "flow law": "composite",
            "adjust detail": 1,
            "jump lower mantle": 0.5,
            "Coh": 300.0,
            "use 3d da file": 1,
            "use 3d da file whole mantle": 1,
            "depth lm middle": 1100e3,
            "V lm": 9e-6,
            "V lm middle": 3e-6
        }
    elif viscosity_jump_type == "1100i":
        mantle_rheology_dict = {
            "scheme": "HK03_WarrenHansen23",
            "flow law": "composite",
            "adjust detail": 1,
            "jump lower mantle": 5.0,
            "Coh": 300.0,
            "use 3d da file": 1,
            "use 3d da file whole mantle": 1,
            "depth lm middle": 1100e3,
            "V lm": 6e-6,
            "V lm middle": 3e-6
        }
    else:
        raise NotImplementedError()

    # option set 0: adapted from the shear-zone (TwoD) model
    case_options0 = {
      "base directory": os.path.join(root_path, "hamageolib/research/haoyuan_2d_subduction/legacy_files/reference_TwoD/240106"),
      "branch": "master_TwoD_rebase_dealii-9.5",
      "output directory": case_root_dir,
      "name": case_name,
      "depth average file": os.path.join(root_path, "hamageolib/research/haoyuan_2d_subduction/legacy_files/reference_TwoD/depth_average.txt"),
      "include fast first step": 1,
      "version": 3.0,
      "test initial steps": {
        "number of outputs": 3,
        "interval of outputs": 10000.0
      },
      "geometry": "box",
      "potential temperature": 1573.0,
      "boundary condition": {
        "model": "all free slip"
      },
      "use world builder": 1,
      "world builder": {
        "use new ridge implementation": 1,
        "plate age method": "adjust box width",
        "box width before adjusting": 15570000.0,
        "adjust mesh with box width": 1,
        "subducting plate": {
          "age trench": 80000000.0,
          "sp rate": 0.05
        },
        "overiding plate": {
          "age": 40000000.0,
          "transit": {
            "age": 20000000.0,
            "length": 700000.0
          }
        }
      },
      "use new rheology module": 1,
      "coupling the eclogite phase to shear zone viscosity": 0,
      "slurm": [
        {
          "slurm file": slurm_base_path,
          "build directory": "master_TwoD_rebase_dealii-9.5",
          "tasks per node": 8,
          "cpus": 8
        }
      ],
      "mantle rheology": {
        "scheme": "HK03_WarrenHansen23",
        "Coh": 500.0,
        "delta Edisl": 0.0
      },
      "include peierls creep": 1,
      "peierls creep": {
        "scheme": "MK10",
        "maximum peierls iterations": 100,
        "fix peierls V as": "dislocation"
      },
      "refinement level": global_refine_level + adaptive_refine_level,
      # The thickness comes from shear_zone_thickness only. A duplicate hard-coded
      # "thickness" key used to silently override it, so the case name (built from
      # shear_zone_thickness) could disagree with the actual setup.
      "shear zone": {
        "thickness": shear_zone_thickness,
        "constant viscosity": shear_zone_visc,
        "cutoff depth": 100000.0
      },
      "phase transition model CDPT type": "HeFESTo_consistent",
      "prescribe temperature method": "plate model 1",
      "prescribe temperature width": 400000.0,
      "outputs": {
        "heat flux": 1
      },
      "refinement": {
        "refine wedge": 1
      },
      "composition method": {
        "scheme": "particle",
        "duplicate op composition": 1,
      },
      'metastable': {
        "include metastable": include_metastable
      }
    }

    # option set 1: adapted from the 3-d-consistent 2-d model
    case_options1 = {
      "base directory": os.path.join(root_path, "hamageolib/research/haoyuan_2d_subduction/legacy_files/reference_TwoD/240106"),
      "branch": "master_TwoD_rebase_dealii-9.5",
      "output directory": case_root_dir,
      "name": case_name,
      "depth average file": os.path.join(root_path, "hamageolib/research/haoyuan_2d_subduction/legacy_files/reference_ThD/depth_average_1573.txt"),
      "include fast first step": 1,
      "version": 3.0,
      "test initial steps": {
        "number of outputs": 3,
        "interval of outputs": 10000.0
      },
      "geometry": geometry,
      "potential temperature": 1573.0,
      "boundary condition": {
        "model": "all free slip"
      },
      "use world builder": 1,
      "world builder": {
        "use new ridge implementation": 1,
        "plate age method": "adjust box width only assigning age",
        "box width before adjusting": 8896000.0,
        "adjust mesh with box width": 1,
        "subducting plate": {
          "age trench": 80000000.0,
          "sp rate": 0.05,
          "trailing length": 600000.0
        },
        "overiding plate": {
          "age": 40000000.0,
          "transit": {
            "age": 20000000.0,
            "length": 700000.0
          },
          "trailing length": 600000.0
        },
        "maximum repetition slice": 1000000.0,
        "fix boudnary temperature auto": 1,
        "box height": box_height
      },
      "coupling the eclogite phase to shear zone viscosity": 0,
      "slurm": [
        {
          "slurm file": slurm_base_path,
          "build directory": "master_TwoD_rebase_dealii-9.5",
          "tasks per node": 8,
          "cpus": 8
        }
      ],
      "use new rheology module": 1,
      "mantle rheology": mantle_rheology_dict,
      "include peierls creep": 1,
      "peierls creep": {
        "scheme": "MK10",
        "maximum peierls iterations": 100,
        "fix peierls V as": "dislocation"
      },
      "refinement level": global_refine_level + adaptive_refine_level,
      "minimum viscosity": 1e+19,
      "refinement scheme": "3d consistent",
      "reset density": 1,
      "refinement": {
        "global refinement": global_refine_level,
        "adaptive refinement": adaptive_refine_level
      },
      "phase transition model CDPT type": "HeFESTo_consistent",
      "shear zone": {
        "thickness": shear_zone_thickness,
        "constant viscosity": shear_zone_visc,
        "slab core viscosity": 1e+22
      },
      "prescribe temperature method": "plate model 1",
      "prescribe temperature width": 900000.0,
      "prescribe temperature with trailing edge": 1,
      "slab": {
        "strength": 300000000.0
      },
      "composition method": {
        "scheme": "particle",
        "duplicate op composition": 1,
      },
      'metastable': {
        "include metastable": include_metastable
      }
    }

    # Select the option set. Any other id is an error; without the else branch,
    # case_options would be left undefined (or stale from an earlier run of this cell).
    if case_option_id == 0:
        case_options = case_options0
    elif case_option_id == 1:
        case_options = case_options1
    else:
        raise NotImplementedError("case_option_id must be 0 or 1, got %s" % case_option_id)

### Create the Case

In [None]:
if run_aspect_2d_test:
    from hamageolib.research.haoyuan_2d_subduction.legacy_tools import create_case_with_json, CASE_TWOD, CASE_OPT_TWOD

    # Make sure the parent output directory exists. os.makedirs also creates missing
    # intermediate directories, where os.mkdir would raise FileNotFoundError.
    os.makedirs(case_root_dir, exist_ok=True)

    # Generate the 2-d case from the option dict selected in the previous cell
    create_case_with_json(case_options, CASE_TWOD, CASE_OPT_TWOD)

### Run the First step

In [None]:
if run_aspect_2d_test and run_aspect_2d_test_first_step:

    # NOTE(review): the case directory is hard-coded to "test_foo1", while the case
    # created above is named case_name (case_options["name"]); confirm which is intended.
    case_dir = os.path.join(case_root_dir, "test_foo1")
    prm_path = os.path.join(case_dir, "case_ini.prm")

    # Run the ASPECT executable on the parameter file from inside case_dir.
    # 'capture_output=True' collects both stdout and stderr (as text) for the checks below.
    completed_process = subprocess.run([aspect_executable, prm_path], capture_output=True, text=True, cwd=case_dir)

    # Capture the standard output and error streams
    stdout = completed_process.stdout
    stderr = completed_process.stderr

    # Uncomment the following lines for debugging purposes to inspect the output
    # print(stdout)  # Debugging: Prints the standard output
    # print(stderr)  # Debugging: Prints the standard error

    # A finished run reports a line like:
    #   -- Total wallclock time elapsed including restarts: 1s
    # Search every line rather than a fixed offset from the end. Extra trailing output then
    # cannot cause a false failure, and a short (failed) run gives a clear message
    # instead of an IndexError.
    finished = any(re.match(".*Total wallclock", line) for line in stdout.split('\n'))
    my_assert(finished, RuntimeError,
              "ASPECT run in %s did not report the total wallclock time (return code %d)" % (case_dir, completed_process.returncode))

    # Ensure that the error stream is empty, indicating no issues during the run
    my_assert(stderr == "", RuntimeError, "ASPECT run in %s wrote to stderr:\n%s" % (case_dir, stderr))

### Visualize the results

Use the jupyter_notebooks/TwoDSubduction/PlotCase.ipynb notebook to visualize the results.

## 3D Case

### Case Options

In [None]:
run_aspect_3d_test = False
run_aspect_3d_test_first_step = False

if run_aspect_3d_test:

    # paths
    aspect_dir = "/home/lochy/Softwares/aspect" # aspect directory (local, optional if not run locally)
    aspect_executable = os.path.join(aspect_dir, "build_master_TwoD_rebase/aspect") # build directory (local, optional if not run locally)

    case_root_dir = "/mnt/lochy/ASPECT_DATA/MOW/mow3_00"
    case_name_base = "mow" # case directory

    include_metastable = 1  # 1 - include metastable
    geometry = "box"; box_height = 2890e3 # geometry setup, 1000e3, 2890e3
    global_refine_level = 3
    adaptive_refine_level = 4

    # Naming options required by mow_case_name (the 6-argument call used previously raised
    # a TypeError). They match the setup below: "jump lower mantle": 60.0 is the "660"
    # profile, and the shear zone is 15 km thick with no constant viscosity set. These are
    # the reference values of mow_case_name, so they add no tags to the case name.
    viscosity_jump_type = "660"
    shear_zone_thickness = 15e3
    shear_zone_visc = 1e20

    case_name = mow_case_name(case_name_base, geometry, box_height, include_metastable,
                              global_refine_level, adaptive_refine_level, viscosity_jump_type,
                              shear_zone_thickness, shear_zone_visc)

    case_options = {
        "_comments": "This case is modified from the Schellart 2007 paper, but include newtonian T-dependent rheology",
        "base directory": os.path.join(root_path, "hamageolib/research/haoyuan_2d_subduction/legacy_files/reference_ThD/07062024"),
        "depth average file": os.path.join(root_path, "hamageolib/research/haoyuan_2d_subduction/legacy_files/reference_ThD/depth_average_1573.txt"),
        "output directory": case_root_dir,
        "version": 3.0,
        "name": case_name,
        "type": "2d_consistent",
        "use world builder": 1,
        "branch": "master_TwoD_rebase_dealii-9.5",
        "post process": {
            "visualization software": "paraview"
        },
        "world builder": {
            "use new ridge implementation": 1
        },
        "include fast first step": 1,
        "geometry": geometry,
        "geometry setup": {
            "box width": 4000000.0,
            "box length": 5000000.0,
            "box height": box_height,
            "box length before adjusting": 8896000.0,
            "adjust box trailing length": 1,
            "repitition slice method": "nearest",
            "fix boudnary temperature auto": 1
        },
        "plate setup": {
            "sp width": 1000000.0,
            "sp length": 3000000.0,
            "trailing length": 0.0,
            "trailing length 1": 600000.0,
            "reset trailing morb": 3,
            "sp depth refining": 300000.0,
            "ov age": 40000000.0,
            "sp age": 80000000.0,
            "assign side plate": 1,
            "ov transit age": 20000000.0,
            "ov transit length": 700000.0,
            "sp ridge x": 0.0,
            "prescribe mantle sp start": 0,
            "ov side dist": 0.0,
            "prescribe mantle ov end": 1,
            "include ov upper plate": 1,
            "strength": 300000000.0
        },
        "use new rheology module": 1,
        "mantle rheology": {
            "scheme": "HK03_WarrenHansen23",
            "flow law": "composite",
            "adjust detail": 1,
            "jump lower mantle": 60.0,
            "Coh": 300.0
        },
        "include peierls creep": 1,
        "peierls creep": {
            "fix peierls V as": "dislocation"
        },
        "shear zone": {
            "thickness": shear_zone_thickness,
            "slab core viscosity": 1e+22
        },
        "slab setup": {
            "length": 530000.0,
            "dip": 70.0
        },
        "refinement": {
            "global refinement": global_refine_level,
            "adaptive refinement": adaptive_refine_level,
            "coarsen minimum refinement level": 2
        },
        "rheology": {
            "reset trailing ov viscosity": 0
        },
        "setup method": "2d_consistent",
        "stokes solver": {
            "type": "block GMG with iterated defect correction Stokes"
        },
        "slurm": [
            {
                "slurm file": os.path.join(root_path, "tests/fixtures/research/haoyuan_2d_subduction/slurm_files/230924/job_frontera-normal.sh"),
                "build directory": "master_TwoD_rebase_dealii-9.5",
                "tasks per node": 56,
                "cpus": 1120
            }
        ],
        "make 2d consistent plate": 2,
        "composition method": {
            "scheme": "particle"
        },
        'metastable': {
            "include metastable": include_metastable
        }
    }


### Create the Case

In [None]:
if run_aspect_3d_test:
    from hamageolib.research.haoyuan_2d_subduction.legacy_tools import create_case_with_json, CASE_THD, CASE_OPT_THD

    # Make sure the parent output directory exists. os.makedirs also creates missing
    # intermediate directories, where os.mkdir would raise FileNotFoundError.
    os.makedirs(case_root_dir, exist_ok=True)

    # Generate the 3-d case from the option dict defined in the previous cell
    create_case_with_json(case_options, CASE_THD, CASE_OPT_THD)

# Post-Process

In this section, I handle the post-processing of metastable cases.

* Visualization: I use a combined workflow of python (mainly pyvista) + paraview + Adobe Illustrator to generate plots

Note:
- do_post_process: controls whether the whole section runs

In [None]:
do_post_process = True

if do_post_process:

    # Imports shared by the post-processing cells below
    import shutil, math
    from shutil import rmtree, copy
    from matplotlib import gridspec, cm
    from PIL import Image, ImageDraw, ImageFont
    from scipy.interpolate import interp1d, UnivariateSpline
    import datetime

    # Working directories: data of this (MOW) project and of the older ThD project
    local_MOW_dir = "/mnt/lochy/ASPECT_DATA/MOW"
    assert os.path.isdir(local_MOW_dir), "MOW data directory not found: %s" % local_MOW_dir

    local_ThD_dir = "/mnt/lochy/ASPECT_DATA/ThDSubduction"
    assert os.path.isdir(local_ThD_dir), "ThD data directory not found: %s" % local_ThD_dir
    use_3d_case = False # use cases in the old project

    # Directory of temporary (py_temp) scripts, and the results directory
    py_temp_dir = os.path.join(root_path, "py_temp_files")
    RESULT_DIR = os.path.join(root_path, 'results')
    os.makedirs(py_temp_dir, exist_ok=True) # Ensure the directory exists

    today_date = datetime.datetime.today().strftime("%Y-%m-%d") # Get today's date in YYYY-MM-DD format
    py_temp_file = os.path.join(py_temp_dir, f"py_temp_{today_date}.sh")

    # One temporary bash script per day, created with a header if it does not exist yet.
    # The header lines are built explicitly. An indented triple-quoted string would carry
    # this cell's indentation into the script ("    # ..." lines).
    if not os.path.exists(py_temp_file):
        bash_header = "\n".join([
            "#!/bin/bash",
            "# =====================================================",
            "# Script: py_temp.sh",
            "# Generated on: {date}".format(date=today_date),
            "# Description: Temporary Bash script created by Python",
            "# =====================================================",
            "",
            "",
        ])
        with open(py_temp_file, "w") as f:
            f.write(bash_header)

    print(f"File ensured at: {py_temp_file}")

## Case name

In [None]:
if do_post_process:
    # Pick the case(s) to post-process by (un)commenting the lines below.
    #   case_name    - 3-d case: under local_ThD_dir if use_3d_case, else under local_MOW_dir (None = skip)
    #   case_name_2d - 2-d case: under local_MOW_dir (None = skip)
    # case_name_2d = "mow_tests/eba2d_width80_h1000_bw4000_sw1000_yd300_M_fix_1"
    # case_name = None; case_name_2d = "mow00/C_mow_h2890.0_M_gr4_ar5"

    # 1000 km
    # without metastable
    # case_name = "mow3_00/C_mow_h1000.0_gr3_ar4"; case_name_2d = "mow_tests/eba2d_width80_h1000_bw4000_sw1000_yd300"
    # with metastable
    # case_name = "mow3_00/C_mow_h1000.0_M_gr3_ar4"; case_name_2d = None
    # case_name = None; case_name_2d = "mow00/C_mow_h2890.0_M_gr4_ar5"
    # case_name_2d = "mow00/C_mow_h2890.0_gr4_ar5"

    # Cases with full domain
    # LABEL: C_mow_gr3_ar4
    # case_name = None; case_name_2d = "mow01/C_mow_h2890.0_gr3_ar4" # meta, low r
    # LABEL: C_mow_M_gr3_ar4
    # case_name = "EBA_2d_consistent_8_6/eba3d_width80_c22_AR4_yd300"; use_3d_case = True; case_name_2d = "mow01/C_mow_h2890.0_M_gr3_ar4" # set PTs for overring plate compositions
    # case_name = None; case_name_2d = "mow00/C_mow_h2890.0_gr3_ar4" # non-meta
    # LABEL: C_mow_gr3_ar4_j1100
    # case_name = None; case_name_2d = "mow01/C_mow_h2890.0_gr3_ar4_jp1100"
    # LABEL: C_mow_M_gr3_ar4_j1100
    # case_name = None; case_name_2d = "mow01/C_mow_h2890.0_M_gr3_ar4_jp1100"
    # LABEL: C_mow_gr3_ar4_j1100i
    # case_name = None; case_name_2d = "mow01/C_mow_h2890.0_gr3_ar4_jp1100i"
    # LABEL: C_mow_M_gr3_ar4_j1100i
    # case_name = None; case_name_2d = "mow01/C_mow_h2890.0_M_gr3_ar4_jp1100i"
    # LABEL: C_mow_M_gr3_ar4_j1100i_szT7.5
    case_name = None; case_name_2d = "mow01/C_mow_h2890.0_M_gr3_ar5_jp1100i_szT7.50_szV5.00e+19"
    # LABEL: C_mow_gr3_ar4_j1100i_szT7.5
    # case_name = None; case_name_2d = "no_mow_sz_jump/C_mow_h2890.0_gr3_ar5_jp1100i_szT7.50_szV5.00e+19"

    local_dir_2d = None
    local_dir = None

    # Resolve every selected case directory and give it img/ and img/pv_outputs/.
    # When both cases are selected, img_dir and pv_img_dir end up pointing into the 3-d case.
    if case_name_2d is not None:
        local_dir_2d = os.path.join(local_MOW_dir, case_name_2d)
        assert(os.path.isdir(local_dir_2d))
        print("local_dir_2d:\n\t", local_dir_2d)
        img_dir = os.path.join(local_dir_2d, "img")
        pv_img_dir = os.path.join(img_dir, "pv_outputs")
        for sub_dir in (img_dir, pv_img_dir):
            if not os.path.isdir(sub_dir):
                os.mkdir(sub_dir)

    if case_name is not None:
        parent_dir = local_ThD_dir if use_3d_case else local_MOW_dir
        local_dir = os.path.join(parent_dir, case_name)
        assert(os.path.isdir(local_dir))
        print("local_dir:\n\t", local_dir)
        img_dir = os.path.join(local_dir, "img")
        pv_img_dir = os.path.join(img_dir, "pv_outputs")
        for sub_dir in (img_dir, pv_img_dir):
            if not os.path.isdir(sub_dir):
                os.mkdir(sub_dir)

## Visualization

### 2-d case

For the 2-d case, I use a combined pyvista + paraview workflow.

- Analysis: this is handled in pyvista (e.g. trench position, slab depth, etc.)
- Generating script: using python to generate script for paraview
- Visualization: running script in paraview

#### Generate plotting scripts

In [None]:
# plot the 2d case: prepare data and paraview script
is_prepare_for_plot_2d = False
is_process_pyvista_for_plot_2d = False

if do_post_process and is_prepare_for_plot_2d:
    from hamageolib.research.mow_subduction.case_options import CASE_OPTIONS_TWOD
    from hamageolib.research.haoyuan_3d_subduction.post_process import ProcessVtuFileTwoDStep

    assert(local_dir_2d is not None)

    # parameters
    graphical_steps = [99] # specify steps
    rotation_plus = 0.47 # rotation of the frame along the lon when making plot
    max_depth = "1000"  # maximum plot depth, 1000, 1300, or 1500

    assert(max_depth in ["1000", "1300", "1500"])
    # The trench position fed into the paraview script comes from a single step. Check this
    # up front, before the (expensive) pyvista processing, instead of inside the script loop.
    my_assert(len(graphical_steps)==1, ValueError, "Feeding the trench position only works when there is only one step")

    # case options
    Case_Options_2d = CASE_OPTIONS_TWOD(local_dir_2d)
    Case_Options_2d.Interpret()
    summary_path_2d = os.path.join(local_dir_2d, "summary.csv")
    Case_Options_2d.SummaryCaseVtuStep(summary_path_2d)
    Case_Options_2d.SummaryCaseVtuStepExport(summary_path_2d)

    # pvtu file index = vtu step + INITIAL_ADAPTIVE_REFINEMENT
    initial_adaptive_refinement = int(Case_Options_2d.options['INITIAL_ADAPTIVE_REFINEMENT'])

    # Processing pyvista
    if is_process_pyvista_for_plot_2d:
        for step in graphical_steps:
            pvtu_step = step + initial_adaptive_refinement
            output_dict = ProcessVtuFileTwoDStep(local_dir_2d, pvtu_step, Case_Options_2d)
            Case_Options_2d.SummaryCaseVtuStepUpdateValue("Slab depth", step, output_dict["slab_depth"])
            Case_Options_2d.SummaryCaseVtuStepUpdateValue("Trench", step, output_dict["trench_center"])

            # print("metastable_area: %.2e km^2" % (output_dict["metastable_area"]/1e6))
        # Write the updated slab depth / trench values back to summary.csv.
        # The export above ran before these updates.
        Case_Options_2d.SummaryCaseVtuStepExport(summary_path_2d)

    # Generate paraview script
    for step in graphical_steps:
        # Get time
        idx = Case_Options_2d.summary_df["Vtu step"] == step
        _time = Case_Options_2d.summary_df.loc[idx, "Time"].values[0]
        pvtu_step = step + initial_adaptive_refinement
        pyvista_outdir = os.path.join(local_dir_2d, "pyvista_outputs", "%05d" % pvtu_step)

        # Get trench center
        # trench_initial = Case_Options_2d.summary_df.loc[0, "Trench"] # there is issue with this
        trench_center = Case_Options_2d.summary_df.loc[idx, "Trench"].values[0]

        # Apply steps
        Case_Options_2d.Interpret(steps=[step])

        # Add additional outputs
        additional_options = {"TRENCH_CENTER": trench_center, "TRENCH_INI_DERIVED": 0.0} # 0.0: initial trench center issue
        for key, value in additional_options.items():
            Case_Options_2d.options[key] = value

        # Add additional plot options
        Case_Options_2d.options["FOO00"] = 1 # this turns on the contour of eq_trans
        Case_Options_2d.options["FOO01"] = 1 # this turns on the contour of 725 C
        Case_Options_2d.options["DA_RANGE"] = [-1e8, 1e8] # range of dynamic pressures
        if max_depth == "1500":
            Case_Options_2d.options["MAX_PLOT_DEPTH_IN_SLICE"] = 1500e3 # turn this on to plot max depth of 1500
        # NOTE(review): max_depth == "1300" sets no MAX_PLOT_DEPTH_IN_SLICE; confirm the
        # paraview script produces a 1300 km slice in that case.
        # Case_Options_2d.options["FOO02"] = 1 # this turns on the metastable area
        # Case_Options_2d.options["FOO03"] = 1 # this turns on the metastable area in the slab

        # Export paraview script (base.py combined with the slab1.py template)
        odir = os.path.join(local_dir_2d, 'paraview_scripts')
        os.makedirs(odir, exist_ok=True)
        print("Generating paraview scripts")
        py_script = 'slab1.py'
        ofile = os.path.join(odir, py_script)
        paraview_script = os.path.join(SCRIPT_DIR, 'paraview_scripts', 'ThDSubduction', py_script)
        paraview_script_base = os.path.join(SCRIPT_DIR, 'paraview_scripts', 'base.py')
        Case_Options_2d.read_contents(paraview_script_base, paraview_script)  # combine these two scripts
        Case_Options_2d.substitute()

        ofile_path = Case_Options_2d.save(ofile, relative=True)
        

#### Automated workflow to finalize visualization

In [None]:
finalize_visual_2d = False

if do_post_process and finalize_visual_2d:

    from hamageolib.research.haoyuan_2d_subduction.workflow_scripts import finalize_visualization_2d_12172024
    from hamageolib.research.haoyuan_3d_subduction.post_process import finalize_visualization_2d_07222025_box

    # model time handed to the finalize functions (set by hand)
    _time = 9.9e6

    # image to finalize
    file_name = "slice_center_viscosity"

    # Frame images (with ticks) passed to the finalize functions
    frame_dir = "/home/lochy/Documents/papers/documented_files/ThDSubduction/Frame"

    # Box geometry: frame file and extra keyword arguments for each plot depth
    # (max_depth is set in the plotting-script cell above)
    box_frames = {
        "1000": ("upper_mantle_frame_07222025_trans_modified_box-01.png", {}),
        "1300": ("upper_mantle_frame_07222025_trans_modified_box_1300-01.png", {"canvas_size": (996, 700)}),
        "1500": ("upper_mantle_frame_07222025_trans_modified_box_1500.png", {"canvas_size": (996, 800), "pos_v_diff": 90}),
    }

    supported_files = ["slice_center_viscosity", "T", "slice_center_density", "slice_center_mow"]
    if file_name in supported_files:
        if Case_Options_2d.options["GEOMETRY"] == "chunk":
            # only a 1000 km frame exists for the chunk geometry
            if max_depth != "1000":
                raise NotImplementedError()
            frame_png_file_with_ticks = os.path.join(frame_dir, "upper_mantle_frame_12172024_trans_modified-01.png")
            output_image_file = finalize_visualization_2d_12172024(local_dir_2d, file_name, _time, frame_png_file_with_ticks, add_time=False)
        elif max_depth in box_frames:
            frame_file, extra_kwargs = box_frames[max_depth]
            frame_png_file_with_ticks = os.path.join(frame_dir, frame_file)
            output_image_file = finalize_visualization_2d_07222025_box(local_dir_2d, file_name, _time, frame_png_file_with_ticks,
                                                                       add_time=False, **extra_kwargs)
        else:
            raise NotImplementedError()

### 3-d case

#### Initiate case options

In [None]:
process_3d_case = False

if do_post_process and process_3d_case:
    from hamageolib.research.haoyuan_3d_subduction.case_options import CASE_OPTIONS

    # local_dir (the 3-d case directory) is selected in the "Case name" cell
    assert local_dir is not None

    # Parse the case options, then build the per-vtu-step summary (read later via summary_df)
    summary_csv_path = os.path.join(local_dir, "summary.csv")
    Case_Options = CASE_OPTIONS(local_dir)
    Case_Options.Interpret()
    Case_Options.SummaryCaseVtuStep(summary_csv_path)

In [None]:
if do_post_process and process_3d_case:
    # Run-time plots of the 3-d case are written to <img_dir>/runtime_plots
    output_dir = os.path.join(img_dir, "runtime_plots")
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    runtime_df = Case_Options.time_df
    plot_helper.generate_runtime_plots(runtime_df, output_dir=output_dir, assemble=True)

#### Generate statistic plots

In [None]:
if do_post_process and process_3d_case:
    # Plots of the ASPECT statistics file go to <img_dir>/statistic_plots,
    # annotated by the "Time step number" column
    statistics_file = os.path.join(local_dir, "output/statistics")
    output_dir = os.path.join(img_dir, "statistic_plots")
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    plot_helper.generate_statistic_plots(statistics_file, output_dir=output_dir,
                                         annotate_column="Time step number", assemble=True)
    file_path = statistics_file

#### Generate plotting scripts

##### Notes

* One reason to turn on is_process_pyvista_for_plot (and rerun the pyvista processing) is to apply a clip to the domain.

In [None]:
is_prepare_for_plot = False
is_process_pyvista_for_plot = True

if do_post_process and process_3d_case and is_prepare_for_plot:

    from hamageolib.research.haoyuan_3d_subduction.post_process import get_trench_position_from_file, get_slab_depth_from_file,\
          PLOT_CASE_RUN_THD, ProcessVtuFileThDStep

    # ---- options ----
    graphical_step = 142
    n_pieces = None  # None: process the whole dataset together; an int (e.g. 16): process piecewise

    # ---- plotting parameters ----
    ofile_list = ["slab1.py"]; require_base=True
    time_range = None
    time_interval = None
    plot_axis = False       # True: save a complete result with axes; False: prepare panels for a paper figure
    slices = None           # specify steps
    # step = "auto"; slices=3  # auto-figure out the steps, take the number of slices
    max_velocity = -1.0     # rescale the color for velocity
    rotation_plus = 0.47    # rotation of the frame along the lon when making plot
    da_range = [-1e8, 1e8]  # range of dynamic pressures
    do_clip = True          # turn this off to plot the whole mantle (needs to generate new pyvista outputs)

    # Plotting helper that generates the paraview scripts
    PlotCaseRunThD = PLOT_CASE_RUN_THD(local_dir, time_range=time_range, run_visual=False,
                                       time_interval=time_interval, visualization="paraview", step=graphical_step,
                                       plot_axis=plot_axis, max_velocity=max_velocity,
                                       rotation_plus=rotation_plus, ofile_list=ofile_list, require_base=require_base)

    # pvtu file index = vtu step + INITIAL_ADAPTIVE_REFINEMENT
    initial_refinement = int(Case_Options.options['INITIAL_ADAPTIVE_REFINEMENT'])
    pvtu_step = graphical_step + initial_refinement
    pyvista_outdir = os.path.join(local_dir, "pyvista_outputs", "%05d" % pvtu_step)

    # Processing pyvista (alternatives for the extra trench depths: [10e3], [0.0])
    extract_trench_at_additional_depths = [50e3]
    if is_process_pyvista_for_plot:
        ProcessVtuFileThDStep(local_dir, pvtu_step, Case_Options, do_clip=do_clip,
                              extract_trench_at_additional_depths=extract_trench_at_additional_depths,
                              n_pieces=n_pieces)

    # Initial trench position from the first pvtu output; -1.0 if that output is missing
    pyvista_outdir0 = os.path.join(local_dir, "pyvista_outputs", "%05d" % initial_refinement)
    try:
        trench_center_ini = get_trench_position_from_file(pyvista_outdir0, initial_refinement, Case_Options.options['GEOMETRY'])
    except FileNotFoundError:
        trench_center_ini = -1.0

    # Model time of the plotted step
    idx = Case_Options.summary_df["Vtu step"] == graphical_step
    _time = Case_Options.summary_df.loc[idx, "Time"].values[0]

    # Trench position (trench_depth=50e3) and slab depth of the plotted step
    trench_center = get_trench_position_from_file(pyvista_outdir, pvtu_step, Case_Options.options['GEOMETRY'], trench_depth=50e3)
    slab_depth = get_slab_depth_from_file(pyvista_outdir, pvtu_step, Case_Options.options['GEOMETRY'],
                                          float(Case_Options.options['OUTER_RADIUS']), "sp_lower")

    # Generate the paraview script
    Case_Options.options["FOO00"] = 1 # this turns on the contour of eq_trans
    Case_Options.options["FOO01"] = 1 # this turns on the contour of 725 C
    addtional_options = {
        "TRENCH_CENTER": trench_center,
        "TRENCH_INI_DERIVED": trench_center_ini,
        "PLOT_TIME": _time,
        "DA_RANGE": str(da_range),
        "FOO00": 1,
        "FOO01": 1,
    }
    PlotCaseRunThD.GenerateParaviewScript(ofile_list, addtional_options)

#### Automated workflow to finalize visualization

In [None]:
finalize_visual = False

if do_post_process and process_3d_case and finalize_visual:

    from hamageolib.research.haoyuan_2d_subduction.workflow_scripts import finalize_visualization_2d_12172024
    from hamageolib.research.haoyuan_3d_subduction.post_process import finalize_visualization_2d_07222025_box

    # model time handed to the finalize functions (set by hand)
    _time = 1.4202e+07

    # image to finalize
    file_name = "slice_center_viscosity"

    max_depth = "1000"  # 1000, 1300, or 1500

    # Frame images (with ticks) passed to the finalize functions
    frame_dir = "/home/lochy/Documents/papers/documented_files/ThDSubduction/Frame"

    # Box geometry: frame file and extra keyword arguments for each plot depth.
    # NOTE(review): the 2-d finalize cell also passes pos_v_diff=90 for the 1500 km frame;
    # confirm whether the 3-d case needs it too.
    box_frames = {
        "1000": ("upper_mantle_frame_07222025_trans_modified_box-01.png", {}),
        "1300": ("upper_mantle_frame_07222025_trans_modified_box_1300-01.png", {"canvas_size": (996, 700)}),
        "1500": ("upper_mantle_frame_07222025_trans_modified_box_1500.png", {"canvas_size": (996, 800)}),
    }

    supported_files = ["slice_center_viscosity", "T", "slice_center_density", "slice_center_mow"]
    if file_name in supported_files:
        if Case_Options.options["GEOMETRY"] == "chunk":
            # only a 1000 km frame exists for the chunk geometry
            if max_depth != "1000":
                raise NotImplementedError()
            frame_png_file_with_ticks = os.path.join(frame_dir, "upper_mantle_frame_12172024_trans_modified-01.png")
            output_image_file = finalize_visualization_2d_12172024(local_dir, file_name, _time, frame_png_file_with_ticks, add_time=False)
        elif max_depth in box_frames:
            frame_file, extra_kwargs = box_frames[max_depth]
            frame_png_file_with_ticks = os.path.join(frame_dir, frame_file)
            output_image_file = finalize_visualization_2d_07222025_box(local_dir, file_name, _time, frame_png_file_with_ticks,
                                                                       add_time=False, **extra_kwargs)
        else:
            raise NotImplementedError()

## Analysis

### Rates and MOW area

In [None]:
is_plot_slab_morphology = False

if do_post_process and is_plot_slab_morphology:
    # Plot time series of slab morphology and kinematics for the 2-d cases in
    # ``dirs_2d`` and the 3-d cases in ``dirs`` on a 2x2 figure, then save it
    # as slab_morphology.pdf / slab_morphology.png under local_dir_2d/img.
    # Every case is read from its summary.csv via the case-options object.
    # 3-d cases take colors after the 2-d ones in the default color cycle.

    from hamageolib.research.mow_subduction.case_options import CASE_OPTIONS_TWOD, CASE_OPTIONS

    import matplotlib.pyplot as plt
    from matplotlib import gridspec
    from matplotlib.ticker import MultipleLocator
    from matplotlib import rcdefaults

    import hamageolib.utils.plot_helper as plot_helper

    # past options
    # compare cases with/without MOW (reference cases) in 2d
    # dirs_2d = [
    #     os.path.join(local_MOW_dir, "mow01/C_mow_h2890.0_M_gr3_ar4"),
    #     os.path.join(local_MOW_dir, "mow01/C_mow_h2890.0_gr3_ar4")
    #     ]
    # dirs = None
    #
    # Compare 2d and 3d cases, with MOW. 
    # to only plot the 3d cases, set dirs_2d to []
    # dirs_2d = [
    #     os.path.join(local_MOW_dir, "mow01/C_mow_h2890.0_M_gr3_ar4"),
    #     os.path.join(local_MOW_dir, "mow01/C_mow_h2890.0_gr3_ar4"),
    # ]

    # Compare the 3d case, with and without a MOW
    # dirs_2d = []
    # dirs = [
    #     os.path.join(local_MOW_dir, "mow3_00/C_mow_h2890.0_M_gr3_ar4"),
    #          "/mnt/lochy/ASPECT_DATA/ThDSubduction/EBA_2d_consistent_8_6/eba3d_width80_c22_AR4_yd300"
    # ]

    # Compare cases with jump at 660 and jump at 1100i
    # Demonstrate the research idea with 7.5 km shear zone and 1e20 Pas as
    # Reference and vary shear zone width accordingly. 
    # dirs_2d = [
    #          os.path.join(local_MOW_dir, "no_mow_sz_jump/C_mow_h2890.0_gr3_ar5_jp1100_szT7.50"),
    #          os.path.join(local_MOW_dir, "no_mow_sz_jump/C_mow_h2890.0_gr3_ar5_jp1100_szV2.00e+20"),
    #         os.path.join(local_MOW_dir, "no_mow_sz_jump/C_mow_h2890.0_gr3_ar5_szT7.50"),
    #         os.path.join(local_MOW_dir, "no_mow_sz_jump/C_mow_h2890.0_gr3_ar5_szV2.00e+20")
    # ]
    # dirs = []

    # Compare cases with viscosity discontinuity at 1100i
    # With MOW and no MOW
    dirs_2d = [
        os.path.join(local_MOW_dir, "no_mow_sz_jump/C_mow_h2890.0_gr3_ar5_jp1100i_szT7.50_szV5.00e+19"),
        os.path.join(local_MOW_dir, "mow01/C_mow_h2890.0_M_gr3_ar5_jp1100i_szT7.50_szV5.00e+19")
    ]
    dirs = []

    max_slab_depth = 2500e3 # m, only plot time steps whose slab depth is shallower than this
    time_marker = None  # yr; when set, draw a dashed vertical line at this time
    factor_2d = 10  # subsampling stride for the 2-d time series
    factor = 10  # subsampling stride for the 3-d time series
    odir = os.path.join(local_dir_2d, "img")

    # makedirs also creates missing parent directories and tolerates an existing one
    os.makedirs(odir, exist_ok=True)

    n_2d = len(dirs_2d)
    n_3d = len(dirs)

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Example usage
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5  # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.8
    line_width_scaling_multiplier = 2.0  # extra scaling multiplier for lines
    t_lim = (0.0, 60.0)
    t_tick_interval = 10.0   # tick interval along x
    y_lim = (-5.0, 5.0)
    y_tick_interval = 100.0  # tick interval along y
    v_lim = (-1.5, 1.5)
    v_level = 50  # number of levels in contourf plot
    v_tick_interval = 0.5  # tick interval along v
    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(
        scaling_factor,
        font_scaling_multiplier=font_scaling_multiplier,
        legend_font_scaling_multiplier=legend_font_scaling_multiplier,
        line_width_scaling_multiplier=line_width_scaling_multiplier
    )

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # Initiate figure
    # ax - trench position
    # ax_twin - slab depth
    # ax1 - trench / plate / sinking velocities
    # ax1_twinx - dip angle
    # ax2 - plate and sinking velocity
    # ax2_twinx - MOW area
    # ax3 - 3d model velocities
    # ax3_twinx - MOW volume
    fig = plt.figure(figsize=(10*scaling_factor, 7*scaling_factor), tight_layout=True)
    gs = gridspec.GridSpec(2, 2)
    
    ax = fig.add_subplot(gs[0, 0])
    ax_twin = ax.twinx()
    ax1 = fig.add_subplot(gs[0, 1])
    ax1_twinx = ax1.twinx()
    ax2 = fig.add_subplot(gs[1, 0])
    ax2_twinx = ax2.twinx()
    ax3 = fig.add_subplot(gs[1, 1])
    ax3_twinx = ax3.twinx()

    # Initiate case options
    # Loop for 2d cases
    for i, _dir_2d in enumerate(dirs_2d):
        Case_Options_2d = CASE_OPTIONS_TWOD(_dir_2d)
        Case_Options_2d.Interpret()
        geometry = Case_Options_2d.options["GEOMETRY"]
        # Cast explicitly: options may hold strings (elsewhere OUTER_RADIUS is passed through float()),
        # and a string cannot scale a numpy array below.
        Ro = float(Case_Options_2d.options["OUTER_RADIUS"])
        Case_Options_2d.SummaryCaseVtuStep(os.path.join(_dir_2d, "summary.csv"))

        # Get plot values
        time_2d_raw = Case_Options_2d.summary_df["Time"].to_numpy()
        trench_center_2d_raw = Case_Options_2d.summary_df["Trench"].to_numpy()
        slab_depth_2d_raw = Case_Options_2d.summary_df["Slab depth"].to_numpy()
        dip_angle_2d_raw = Case_Options_2d.summary_df["Dip 100"].to_numpy()
        mow_area_2d_raw = Case_Options_2d.summary_df["Mow area"].to_numpy()
        mow_area_slab_2d_raw = Case_Options_2d.summary_df["Mow area cold"].to_numpy()
        sp_velocity_2d_raw = Case_Options_2d.summary_df["Sp velocity"].to_numpy()

        # mask on slab depth 
        mask0 = (slab_depth_2d_raw < max_slab_depth)
        time_2d = time_2d_raw[mask0]
        trench_center_2d = trench_center_2d_raw[mask0]
        slab_depth_2d = slab_depth_2d_raw[mask0]
        dip_angle_2d = dip_angle_2d_raw[mask0]
        mow_area_2d = mow_area_2d_raw[mask0]
        mow_area_slab_2d = mow_area_slab_2d_raw[mask0]
        sp_velocity_2d = sp_velocity_2d_raw[mask0]

        # In chunk geometry the trench is stored as an angle; scale by the outer radius to get a length
        if geometry == "chunk":
            trench_center_2d *= Ro

        # plot slab dip angle and trench position

        Xs_2d = time_2d / 1e6  # yr -> Myr
        Ys_2d = (trench_center_2d - trench_center_2d[0]) / 1e3  # trench displacement, m -> km
        Ys_2d_1 = slab_depth_2d / 1e3  # m -> km
        # d(km)/d(Myr) * 0.1 = cm/yr
        dx_dy_2d = np.gradient(Ys_2d[::factor_2d], Xs_2d[::factor_2d]) / 1e3 * 1e2
        dx_dy_2d_1 = np.gradient(Ys_2d_1[::factor_2d], Xs_2d[::factor_2d]) / 1e3 * 1e2
        ax.plot(Xs_2d[::factor_2d], Ys_2d[::factor_2d], label="Trench 2d", color=default_colors[i])
        if i==0 and time_marker is not None:
            ax.vlines(time_marker/1e6, linestyle="--", color="k", ymin=-150.0, ymax=100.0, linewidth=1)
        ax_twin.plot(Xs_2d[::factor_2d], Ys_2d_1[::factor_2d], linestyle="-.", label="Slab Depth 2d", color=default_colors[i])

        if i == 0:
            lines, labels = ax.get_legend_handles_labels()
            lines2, labels2 = ax_twin.get_legend_handles_labels()
            ax.legend(lines + lines2, labels + labels2, loc="upper right")

        # plot velocity and dip angle

        ax1.plot(Xs_2d[::factor_2d], dx_dy_2d, label="Trench Velocity 2d", color=default_colors[i])
        ax1.plot(Xs_2d[::factor_2d], sp_velocity_2d[::factor_2d]*100.0, label="Sp Velocity 2d", color=default_colors[i], linewidth=3)
        ax1.plot(Xs_2d[::factor_2d], dx_dy_2d_1, linestyle="-.", label="Sinking Velocity 2d", color=default_colors[i])
        
        # radians -> degrees
        ax1_twinx.plot(Xs_2d[::factor_2d], dip_angle_2d[::factor_2d]*180.0/np.pi, label="Dip 100 2d", linestyle="--", color=default_colors[i])

        if i == 0:
            lines, labels = ax1.get_legend_handles_labels()
            lines2, labels2 = ax1_twinx.get_legend_handles_labels()
            ax1.legend(lines + lines2, labels + labels2, loc="upper right")
        
        # plot velocity and mow area
        ax2.plot(Xs_2d[::factor_2d], dx_dy_2d_1, linestyle="-.", label="Sinking Velocity 2d", color=default_colors[i])
        ax2.plot(Xs_2d[::factor_2d], sp_velocity_2d[::factor_2d]*100.0, label="Sp Velocity 2d", color=default_colors[i], linewidth=3)
        
        # m^2 -> km^2
        ax2_twinx.plot(Xs_2d[::factor_2d], mow_area_2d[::factor_2d]/1e6, label="MOW Area, 2d", linestyle="-", color=default_colors[i], linewidth=1)
        ax2_twinx.plot(Xs_2d[::factor_2d], mow_area_slab_2d[::factor_2d]/1e6, label="MOW Area in slab, 2d", linestyle="-", color=default_colors[i], linewidth=2)

        if i == 0:
            lines, labels = ax2.get_legend_handles_labels()
            lines2, labels2 = ax2_twinx.get_legend_handles_labels()
            ax2.legend(lines + lines2, labels + labels2, loc="upper right")


    # Loop for 3d cases
    for i, _dir in enumerate(dirs):
        Case_Options = CASE_OPTIONS(_dir)
        Case_Options.Interpret()
        geometry = Case_Options.options["GEOMETRY"]
        # Cast explicitly: options may hold strings (see the 2-d loop above)
        Ro = float(Case_Options.options["OUTER_RADIUS"])
        Case_Options.SummaryCaseVtuStep(os.path.join(_dir, "summary.csv"))

        # Get plot values
        time_raw = Case_Options.summary_df["Time"].to_numpy()
        trench_center_raw = Case_Options.summary_df["Trench (center)"].to_numpy()
        slab_depth_raw = Case_Options.summary_df["Slab depth"].to_numpy()
        dip_angle_raw = Case_Options.summary_df["Dip 100 (center)"].to_numpy()
        mow_volume_raw = Case_Options.summary_df["MOW volume"].to_numpy()
        mow_volume_slab_raw = Case_Options.summary_df["MOW volume cold"].to_numpy()
        mow_area_raw = Case_Options.summary_df["Mow area center"].to_numpy()
        mow_area_slab_raw = Case_Options.summary_df["Mow area cold center"].to_numpy()
        sp_velocity_raw = Case_Options.summary_df["Sp velocity"].to_numpy()

        # mask on slab depth 
        mask0 = (slab_depth_raw < max_slab_depth)
        _time = time_raw[mask0]
        trench_center = trench_center_raw[mask0]
        slab_depth = slab_depth_raw[mask0]
        dip_angle = dip_angle_raw[mask0]
        mow_area = mow_area_raw[mask0]
        mow_area_slab = mow_area_slab_raw[mask0]
        mow_volume = mow_volume_raw[mask0]
        mow_volume_slab = mow_volume_slab_raw[mask0]
        sp_velocity = sp_velocity_raw[mask0]

        if geometry == "chunk":
            trench_center *= Ro

        # plot slab dip angle and trench position
        Xs = _time / 1e6
        Ys = (trench_center - trench_center[0]) / 1e3
        Ys_1 = slab_depth / 1e3
        dx_dy = np.gradient(Ys[::factor], Xs[::factor]) / 1e3 * 1e2
        dx_dy_1 = np.gradient(Ys_1[::factor], Xs[::factor]) / 1e3 * 1e2
        ax.plot(Xs[::factor], Ys[::factor], label="Trench", color=default_colors[i+n_2d])
        if i==0 and time_marker is not None:
            ax.vlines(time_marker/1e6, linestyle="--", color="k", ymin=-150.0, ymax=100.0, linewidth=1)
        ax_twin.plot(Xs[::factor], Ys_1[::factor], linestyle="-.", label="Slab Depth", color=default_colors[i+n_2d])

        if i == 0:
            lines, labels = ax.get_legend_handles_labels()
            lines2, labels2 = ax_twin.get_legend_handles_labels()
            ax.legend(lines + lines2, labels + labels2, loc="upper right")

        # plot velocity and dip angle

        ax1.plot(Xs[::factor], dx_dy, label="Trench Velocity", color=default_colors[i+n_2d])
        ax1.plot(Xs[::factor], sp_velocity[::factor]*100.0, label="Sp Velocity", color=default_colors[i+n_2d], linewidth=3)
        ax1.plot(Xs[::factor], dx_dy_1, linestyle="-.", label="Sinking Velocity", color=default_colors[i+n_2d])
        
        ax1_twinx.plot(Xs[::factor], dip_angle[::factor]*180.0/np.pi, label="Dip 100", linestyle="--", color=default_colors[i+n_2d])

        # plot velocity and mow area

        ax2.plot(Xs[::factor], dx_dy_1, linestyle="-.", label="Sinking Velocity", color=default_colors[i+n_2d])
        ax2.plot(Xs[::factor], sp_velocity[::factor]*100.0, label="Sp Velocity", color=default_colors[i+n_2d], linewidth=3)
        
        ax2_twinx.plot(Xs[::factor], mow_area_slab[::factor]/1e6, label="MOW Area in slab", linestyle="-", color=default_colors[i+n_2d], linewidth=2)

        # plot velocity and mow volume (m^3 -> km^3)
        ax3.plot(Xs[::factor], dx_dy_1, linestyle="-.", label="Sinking Velocity", color=default_colors[i+n_2d])
        ax3.plot(Xs[::factor], sp_velocity[::factor]*100.0, label="Sp Velocity", color=default_colors[i+n_2d], linewidth=3)
        
        ax3_twinx.plot(Xs[::factor], mow_volume_slab[::factor]/1e9, label="MOW Volume in slab", linestyle="-", color=default_colors[i+n_2d], linewidth=2)


    # configuration of figures
    spine_width = 0.5 * scaling_factor * line_width_scaling_multiplier

    def _set_spine_width(*axes_list):
        """Apply the shared spine line width to every Axes in axes_list."""
        for _axes in axes_list:
            for spine in _axes.spines.values():
                spine.set_linewidth(spine_width)

    def _set_tick_locators(axis, major_interval):
        """Place major ticks every major_interval with n_minor_ticks minor ticks in between."""
        axis.set_major_locator(MultipleLocator(major_interval))
        axis.set_minor_locator(MultipleLocator(major_interval/(n_minor_ticks+1)))

    # ax / ax_twin: trench position and slab depth
    ax.set_xlim(t_lim)
    ax.set_ylim([-800.0, 200.0])
    ax_twin.set_ylim([0, 2500.0])
    ax.set_xlabel("Time (Ma)")
    ax.set_ylabel("Trench (km)")
    ax.grid()
    _set_spine_width(ax, ax_twin)
    _set_tick_locators(ax.xaxis, t_tick_interval)
    _set_tick_locators(ax.yaxis, 200.0)
    _set_tick_locators(ax_twin.yaxis, 500.0)

    # ax1 / ax1_twinx: velocities and dip angle
    ax1.set_xlim(t_lim)
    ax1.set_ylim([-5.0, 20.0])
    ax1.set_xlabel("Time (Ma)")
    ax1.set_ylabel("Velocity (cm/yr)")
    ax1.grid()
    _set_tick_locators(ax1.xaxis, t_tick_interval)
    _set_tick_locators(ax1.yaxis, 5.0)
    ax1_twinx.set_ylim([20.0, 70.0])
    _set_tick_locators(ax1_twinx.yaxis, 10.0)
    _set_spine_width(ax1, ax1_twinx)

    # ax2 / ax3 share the same velocity-axis layout
    for _velocity_axes in (ax2, ax3):
        _velocity_axes.set_xlim(t_lim)
        _velocity_axes.set_ylim([0.0, 20.0])
        _velocity_axes.set_xlabel("Time (Ma)")
        _velocity_axes.set_ylabel("Velocity (cm/yr)")
        _velocity_axes.grid()
        _set_tick_locators(_velocity_axes.xaxis, t_tick_interval)
        _set_tick_locators(_velocity_axes.yaxis, 5.0)

    # ax2_twinx: MOW area
    ax2_twinx.set_ylabel("MOW Area (km^2)")
    ax2_twinx.set_ylim([0.0, 8000.0])
    _set_tick_locators(ax2_twinx.yaxis, 2000.0)
    _set_spine_width(ax2, ax2_twinx)

    # ax3_twinx: MOW volume
    ax3_twinx.set_ylabel("MOW Volume (km^3)")
    ax3_twinx.set_ylim([0.0, 8e6])
    _set_tick_locators(ax3_twinx.yaxis, 2e6)
    _set_spine_width(ax3, ax3_twinx)

    # save figure
    filepath = os.path.join(odir, "slab_morphology.pdf")
    fig.savefig(filepath)
    print("Saved figure: ", filepath)
    filepath_png = os.path.join(odir, "slab_morphology.png")
    fig.savefig(filepath_png)
    print("Saved figure: ", filepath_png)

## Diagram

### 2-d Case

#### Read data
In this block:
  * Use the vtk package to read a .pvtu file (set `step` to the snapshot to read, e.g. 49)
  * Convert the data (coordinates and fields) to numpy arrays
  * Initiate interpolators for every point-data field (e.g. T, p, density, metastable) and for the mesh resolution

In [None]:
check_diagram_2d = False

if do_post_process and check_diagram_2d:
    # Read one 2-d .pvtu snapshot with vtk, convert the point coordinates and
    # every point-data field to numpy, and build a LinearNDInterpolator for
    # each field (plus the mesh resolution) in the dict ``interpolators``.

    # Options:
    # step - the step to plot (step where visualization is generated)
    step = 100

    # Import module
    import vtk
    from vtk.util.numpy_support import vtk_to_numpy
    from hamageolib.utils.vtk_utilities import calculate_resolution
    import time
    from scipy.interpolate import LinearNDInterpolator
    from scipy.spatial import Delaunay
    from hamageolib.research.mow_subduction.case_options import CASE_OPTIONS_TWOD

    # Case options
    Case_Options_2d = CASE_OPTIONS_TWOD(local_dir_2d)
    Case_Options_2d.Interpret()
    Case_Options_2d.SummaryCaseVtuStep(os.path.join(local_dir_2d, "summary.csv"))

    # Find vtu file (the file index is the step shifted by INITIAL_ADAPTIVE_REFINEMENT)
    pvtu_step = step + int(Case_Options_2d.options['INITIAL_ADAPTIVE_REFINEMENT'])
    pvtu_file = os.path.join(local_dir_2d, "output", "solution", "solution-%05d.pvtu" % pvtu_step)
    # Raise explicitly rather than `assert`, which is stripped under `python -O`
    if not os.path.isfile(pvtu_file):
        raise FileNotFoundError("pvtu file not found: %s" % pvtu_file)

    # Read the pvtu file
    start = time.time()

    reader = vtk.vtkXMLPUnstructuredGridReader()
    reader.SetFileName(pvtu_file)
    reader.Update()

    end = time.time()
    print("Initiating reader takes %.2e s" % (end - start))
    start = end

    # Get the output data from the reader
    grid = reader.GetOutput()  # Access the unstructured grid
    data_set = reader.GetOutputAsDataSet()  # Access the dataset representation
    points = grid.GetPoints()  # Extract the points (coordinates)
    cells = grid.GetCells()  # Extract the cell connectivity information
    point_data = data_set.GetPointData()  # Access point-wise data

    n_points = grid.GetNumberOfPoints() # Number of points and cells
    n_cells = grid.GetNumberOfCells()

    end = time.time()
    print("Reading files takes %.2e s" % (end - start))
    print(f"\tNumber of points: {n_points}")
    print(f"\tNumber of cells: {n_cells}")
    print("\tAvailable point data fields:")
    for i in range(point_data.GetNumberOfArrays()):
        # Field names in point data
        name = point_data.GetArrayName(i)
        print(f"\t  - {name}")
    start = end

    # Convert the point coordinates to numpy
    vtk_points = grid.GetPoints().GetData()
    points_np = vtk_to_numpy(vtk_points)  # Shape: (n_points, 3)
    points_2d = points_np[:, :2]  # Use only the first two columns for 2D coordinates

    # Calculate resolution for each point in the grid
    # (calculate_resolution is imported from hamageolib.utils.vtk_utilities)
    resolution_np = calculate_resolution(grid)

    end = time.time()
    print("Calculating resolution takes %.2e s" % (end - start))
    start = end

    # Triangulate once and share it: passing raw points to LinearNDInterpolator
    # would recompute the same Delaunay triangulation for every field.
    triangulation = Delaunay(points_2d)

    # Initialize dictionary for interpolators
    interpolators = {}

    # Loop over all arrays in point data
    num_arrays = point_data.GetNumberOfArrays()
    for i in range(num_arrays):
        array_name = point_data.GetArrayName(i)
        vtk_array = point_data.GetArray(i)
        
        if vtk_array is None:
            print(f"Warning: Array {array_name} is None, skipping.")
            continue
        
        # Convert VTK array to NumPy
        np_array = vtk_to_numpy(vtk_array)
        
        # Create interpolator and add to dict; points outside the mesh evaluate to NaN
        interpolators[array_name] = LinearNDInterpolator(triangulation, np_array, fill_value=np.nan)

    # Interpolator for resolution
    interpolators["resolution"] = LinearNDInterpolator(triangulation, resolution_np)

    end = time.time()
    print("Construct linear ND interpolator takes %.2e s" % (end - start))
    start = end

#### Generate grid

Next, use the interpolators to evaluate the fields on a regular grid for plotting.

- Note that the interval is given in meters.
- xs and ys are generated with slightly different intervals, so the grid always has a different number of nodes along x and y. This makes debugging easier: when a 2-d array has different sizes along its two dimensions, it is easy to tell which axis is which.

In [None]:
if do_post_process and check_diagram_2d:
    # Evaluate the mesh interpolators on a regular (x, y) grid spanning the
    # bounding box of the mesh points. Because meshgrid uses indexing="ij",
    # every *_grid array has shape (len(xs), len(ys)).

    start = time.time()

    # Nominal grid spacing (m)
    interval = 10e3

    # Bounding box of the mesh points
    lower_corner = np.min(points_2d, axis=0)
    upper_corner = np.max(points_2d, axis=0)
    x_min, y_min = lower_corner
    x_max, y_max = upper_corner

    # Slightly different spacings in x and y give the two grid dimensions
    # different lengths, so a transposed array shows up as a shape mismatch.
    xs = np.arange(x_min, x_max, interval*0.99)
    ys = np.arange(y_min, y_max, interval*1.01)
    x_grid, y_grid = np.meshgrid(xs, ys, indexing="ij")

    # Query points as an (n, 2) array of (x, y)
    grid_points_2d = np.column_stack((x_grid.ravel(), y_grid.ravel()))

    def _sample_on_grid(field_name):
        """Evaluate the linear interpolator of ``field_name`` on the grid, reshaped to x_grid.shape."""
        return interpolators[field_name](grid_points_2d).reshape(x_grid.shape)

    T_grid = _sample_on_grid("T")                      # temperature
    P_grid = _sample_on_grid("p")                      # pressure
    density_grid = _sample_on_grid("density")          # density
    metastable_grid = _sample_on_grid("metastable")    # metastable content
    resolutions_grid = _sample_on_grid("resolution")   # mesh resolution

    end = time.time()
    print("Interpolating to regular grid takes %.2e s" % (end - start))
    print("\tgrid shape: (x axis, y axis): ", x_grid.shape)
    start = end

#### Interpolate the results onto a regular P-T grid

In [None]:
if do_post_process and check_diagram_2d:
    # Re-sample the metastable content from the regular (x, y) grid onto a
    # regular grid in normalized (T/T_max, P/P_max) space with nearest-neighbor
    # lookup. Nodes farther than max_distance from every sample are left NaN.

    # Load modules
    from scipy.spatial import cKDTree

    # Option: maximum nearest-neighbor distance, in normalized (T, P) units
    max_distance = 0.1  

    # Grid nodes outside the mesh are NaN (the interpolators use fill_value=np.nan).
    # Drop them, otherwise min/max return NaN and the KD-tree gets non-finite data.
    T_flat = T_grid.ravel()
    P_flat = P_grid.ravel()
    finite_mask = np.isfinite(T_flat) & np.isfinite(P_flat)
    if not np.any(finite_mask):
        raise ValueError("No finite T/P samples on the regular grid; check the interpolation step.")
    T_valid = T_flat[finite_mask]
    P_valid = P_flat[finite_mask]

    # Get the P, T limits
    T_min = np.min(T_valid)
    T_max = np.max(T_valid)
    P_min = np.min(P_valid)
    P_max = np.max(P_valid)

    # 1. Build a KD-tree over the normalized samples
    data_points = np.column_stack((T_valid/T_max, P_valid/P_max))
    tree = cKDTree(data_points)
    metastable_values = metastable_grid.ravel()[finite_mask]

    # 2. Create a regular grid in normalized space
    T_lin = np.linspace(T_min/T_max, 1.0, 300)
    P_lin = np.linspace(P_min/P_max, 1.0, 300)
    T_reg, P_reg = np.meshgrid(T_lin, P_lin)
    grid_points = np.column_stack((T_reg.ravel(), P_reg.ravel()))

    # 3. Interpolate the data
    # Nearest neighbor interpolation; nodes with no sample within max_distance stay NaN
    distances, indices = tree.query(grid_points, k=1)
    mask = distances <= max_distance
    metastable_interp_flat = np.full(grid_points.shape[0], np.nan)
    metastable_interp_flat[mask] = metastable_values[indices[mask]]
    metastable_interp = metastable_interp_flat.reshape(T_reg.shape)

#### Plot the P, T diagram

In [None]:
if do_post_process and check_diagram_2d:
    # Plot sampled fields in T-P space on a 2x2 figure:
    #   top-left: metastable content on the regular (x, y) grid
    #   top-right: density on the regular (x, y) grid
    #   bottom-left: metastable content re-sampled on the normalized T-P grid
    # Pressure is shown in GPa and increases downward.

    # Load modules
    # gridspec, MultipleLocator and rcdefaults are otherwise only imported in the
    # slab-morphology cell, which is disabled by default; import them here so
    # this cell does not depend on that cell having run.
    from matplotlib import gridspec
    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    from cmcrameri import cm as ccm

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # Example usage
    # Rule of thumbs:
    # 1. Set the limit to something like 5.0, 10.0 or 50.0, 100.0 
    # 2. Set five major ticks for each axis
    scaling_factor = 1.0  # scale factor of plot
    font_scaling_multiplier = 1.5 # extra scaling multiplier for font
    legend_font_scaling_multiplier = 0.5
    line_width_scaling_multiplier = 2.0 # extra scaling multiplier for lines

    T_lim = (400.0, 1800.0) # T (K)
    T_level = 50  # number of levels in contourf plot
    T_tick_interval = 200.0  # tick interval along T

    T_lim1 = (0.0, 800.0) # T (C), smaller scale
    T_tick_interval1 = 200.0  # tick interval along x

    P_lim = (10.0, 30.0) # P (GPa)
    P_level = 50  # number of levels in contourf plot
    P_tick_interval = 5.0  # tick interval along P

    density_lim = (3000.0, 4000.0)
    density_level = 50  # number of levels in contourf plot
    density_tick_interval = 100.0  # tick interval on the density colorbar

    metastable_lim = (0.0, 1.0) # metastable contents
    metastable_level = 100
    metastable_interval = 0.2

    n_minor_ticks = 4  # number of minor ticks between two major ones

    # scale the matplotlib params
    plot_helper.scale_matplotlib_params(scaling_factor, font_scaling_multiplier=font_scaling_multiplier,\
                            legend_font_scaling_multiplier=legend_font_scaling_multiplier,
                            line_width_scaling_multiplier=line_width_scaling_multiplier)

    # Update font settings for compatibility with publishing tools like Illustrator.
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'pdf.fonttype': 42,
        'ps.fonttype': 42
    })

    # Create a figure
    fig = plt.figure(figsize=(10, 8), tight_layout=True)
    gs = gridspec.GridSpec(2, 2)

    def _colorbar_ticks(lim, interval):
        """Return ticks from lim[0] to lim[1], both ends included, spaced by interval."""
        # np.arange excludes its stop value; extend by half an interval to keep lim[1]
        return np.arange(lim[0], lim[1] + 0.5*interval, interval)

    def _plot_pt_panel(subplot_spec, T_values, P_values_gpa, field, field_lim, field_tick_interval, cmap, label):
        """Color-map ``field`` over (T, P) in one subplot with the shared T-P axis style.

        Returns the created Axes.
        """
        panel = fig.add_subplot(subplot_spec)

        # NOTE(review): recent matplotlib versions reject non-finite X/Y in
        # pcolormesh; T_grid/P_grid are NaN where the grid falls outside the
        # mesh (e.g. chunk geometry) - confirm before using this on such cases.
        color_map = panel.pcolormesh(T_values, P_values_gpa, field,
                                     vmin=field_lim[0], vmax=field_lim[1], cmap=cmap)

        cbar = fig.colorbar(color_map, ax=panel, label=label)  # Add colorbar
        cbar.set_ticks(_colorbar_ticks(field_lim, field_tick_interval))

        panel.set_xlim(T_lim)
        panel.set_ylim(P_lim)

        panel.xaxis.set_major_locator(MultipleLocator(T_tick_interval))
        panel.xaxis.set_minor_locator(MultipleLocator(T_tick_interval/(n_minor_ticks+1)))
        panel.yaxis.set_major_locator(MultipleLocator(P_tick_interval))
        panel.yaxis.set_minor_locator(MultipleLocator(P_tick_interval/(n_minor_ticks+1)))

        panel.grid()

        # Pressure increases downward
        panel.invert_yaxis()

        panel.set_xlabel("T (K)")
        panel.set_ylabel("P (GPa)")

        for spine in panel.spines.values():
            # Adjust spine thickness for this plot
            spine.set_linewidth(0.5 * scaling_factor * line_width_scaling_multiplier)

        return panel

    # Plot the diagram of metastable composition
    ax = _plot_pt_panel(gs[0, 0], T_grid, P_grid/1e9, metastable_grid,
                        metastable_lim, metastable_interval, "viridis", "Metastable")

    # Plot the diagram of density
    ax = _plot_pt_panel(gs[0, 1], T_grid, P_grid/1e9, density_grid,
                        density_lim, density_tick_interval, ccm.batlow, "Density")

    # Plot the diagram of interpolated metastable composition (undo the T, P normalization)
    ax = _plot_pt_panel(gs[1, 0], T_reg*T_max, P_reg*P_max/1e9, metastable_interp,
                        metastable_lim, metastable_interval, "viridis", "Metastable")

    # ofile = os.path.join(case_dir, "metastable_diagram_nrep_%d_vstep_%05d.pdf" % (n_repetition, vtu_step))
    # fig.savefig(ofile)
    # print("saved figure %s" % ofile)
    
    # Reset rcParams to defaults
    rcdefaults()

## Animation

### 2-d case, basic

Note: the scripts that make the plots must be generated first and then run in a terminal. The final animation can only be assembled after all of these figures have been generated, so in practice this section needs to be run a couple of times.

Task list
- "TRENCH_CENTER": replace with the real trench location so that the triangle marker shows up at the right place
- Finalize results from paraview: use frames with ticks
- Assemble the animation: assemble the colorbars for the plots as well

In [None]:
# Options (switches for the basic 2-d animation section)
# animate_2d_case_basic - master switch for this section
# debug_step0_animate_2d_case_basic - only process the first frame (for debugging)
# generate_paraview_scripts_for_animate_2d_case_basic - write one paraview script per frame
animate_2d_case_basic = True
debug_step0_animate_2d_case_basic = False
generate_paraview_scripts_for_animate_2d_case_basic = True

if animate_2d_case_basic:

    from hamageolib.research.mow_subduction.case_options import CASE_OPTIONS_TWOD
    from hamageolib.research.haoyuan_3d_subduction.post_process import ProcessVtuFileTwoDStep

    # Animation settings
    animation_name= "ani_basic"
    max_depth = "1500"       # maximum plotting depth in km, as a string key
    time_interval = 0.5e6    # spacing between animation frames, in the units of the "Time" column (yr)

    summary_csv = os.path.join(local_dir_2d, "summary.csv")

    # Load the case options; both summary calls operate on summary.csv
    Case_Options_2d = CASE_OPTIONS_TWOD(local_dir_2d)
    Case_Options_2d.Interpret()
    Case_Options_2d.SummaryCaseVtuStep(summary_csv)
    Case_Options_2d.SummaryCaseVtuStepExport(summary_csv)

    # Visualization steps resampled at roughly time_interval spacing
    resampled_df = Case_Options_2d.resample_visualization_df(time_interval)
    graphical_steps = resampled_df["Vtu step"].values

#### Generate paraview scripts stepwise

This optional step generates stepwise paraview scripts (controlled by generate_paraview_scripts_for_animate_2d_case_basic).
After this step, run py_temp_foo.py in a terminal to generate the visualization.

In [None]:
# Loop over the resampled time steps and generate one paraview script per step.
# Each generated script is registered as a `pvpython <script>` line in
# py_temp_file, which is then run from a terminal to produce the images.
if animate_2d_case_basic and generate_paraview_scripts_for_animate_2d_case_basic:

    # The same paraview template is rendered for every step
    py_script = "slab1.py"
    paraview_script = os.path.join(SCRIPT_DIR, 'paraview_scripts', 'ThDSubduction', py_script)
    paraview_script_base = os.path.join(SCRIPT_DIR, 'paraview_scripts', 'base.py')

    # Make the directory to hold the stepwise scripts (parents created as needed)
    ps_dir = os.path.join(local_dir_2d, 'paraview_scripts')
    odir = os.path.join(ps_dir, "stepwise")
    os.makedirs(odir, exist_ok=True)

    # Spacing of graphical outputs; each resampled time is snapped onto it
    graphical_dt = float(resampled_df.attrs["Time between graphical output"])

    print("Start generating paraview scripts")
    # The context manager closes py_temp_file even if a step raises
    with open(py_temp_file, 'w') as fout:
        fout.write("#!/bin/bash\n")

        for i, _time in enumerate(resampled_df["Time"].values):

            # debug run: only process step 0
            if debug_step0_animate_2d_case_basic and i > 0:
                break

            # Stepwise configurations
            time_rounded = round(_time / graphical_dt) * graphical_dt
            step = graphical_steps[i]
            print("\tGenerating paraview scripts for step = %d, time = %.4e" % (step, time_rounded))

            # Apply stepwise configuration
            Case_Options_2d.options['GRAPHICAL_STEPS'] = [step]
            Case_Options_2d.options['GRAPHICAL_TIMES'] = [time_rounded]
            Case_Options_2d.options["TRENCH_CENTER"] = -1.0 # TODO: modify with real trench location
            Case_Options_2d.options["FOO00"] = 1 # this turns on the contour of eq_trans
            Case_Options_2d.options["FOO01"] = 1 # this turns on the contour of 725 C
            Case_Options_2d.options["FOO02"] = 1 # this turns on the metastable area
            Case_Options_2d.options["FOO03"] = 1 # this turns on the metastable area in the slab
            Case_Options_2d.options["DA_RANGE"] = "[-1e8, 1e8]"
            if max_depth == "1500":
                Case_Options_2d.options["MAX_PLOT_DEPTH_IN_SLICE"] = 1500e3

            # Re-read the templates for this step, substitute options and save the script
            ofile = os.path.join(odir, 'slab_%d.py' % (step))
            Case_Options_2d.read_contents(paraview_script_base, paraview_script)
            Case_Options_2d.substitute()
            Case_Options_2d.save(ofile)

            # Register the script in py_temp_file
            fout.write("pvpython %s\n" % ofile)

    # Make py_temp_file executable; check=True surfaces a failed chmod instead of ignoring it
    subprocess.run(["chmod", "+x", py_temp_file], check=True)
    print("saved file: %s" % py_temp_file)

#### Finalize plot from paraview

Note: the previous step (running the generated paraview scripts in a terminal) must be completed first.

In [None]:
if animate_2d_case_basic:

    from hamageolib.research.haoyuan_2d_subduction.workflow_scripts import finalize_visualization_2d_12172024
    from hamageolib.research.haoyuan_3d_subduction.post_process import finalize_visualization_2d_07222025_box

    # file types: one paraview slice per field, plus the metastable slice for "mow" models
    file_name_list = ["slice_center_viscosity", "slice_center_temperature", "slice_center_density"]
    if Case_Options_2d.options["MODEL_TYPE"] == "mow":
        file_name_list += ["slice_center_mow"]

    # The frame image and finalize function depend only on the geometry and
    # max_depth, so resolve them once and fail early on unsupported combinations.
    # Each entry: max_depth -> (frame file name, extra keyword arguments).
    frame_dir = "/home/lochy/Documents/papers/documented_files/ThDSubduction/Frame"
    geometry = Case_Options_2d.options["GEOMETRY"]
    if geometry == "chunk":
        finalize_func = finalize_visualization_2d_12172024
        frame_options = {
            "1000": ("upper_mantle_frame_12172024_trans_modified-01.png", {}),
        }
    else:
        finalize_func = finalize_visualization_2d_07222025_box
        frame_options = {
            "1000": ("upper_mantle_frame_07222025_trans_modified_box_with_frame.png",
                     {"canvas_size": (1040, 610)}),
            "1300": ("upper_mantle_frame_07222025_trans_modified_box_1300-01.png",
                     {"canvas_size": (1040, 610)}),
            "1500": ("upper_mantle_frame_07222025_trans_modified_box_with_frame_1500.png",
                     {"canvas_size": (1040, 800), "pos_v_diff": 90}),
        }
    if max_depth not in frame_options:
        raise NotImplementedError(
            "max_depth = %s is not supported for geometry %s" % (max_depth, geometry))
    frame_file_name, extra_finalize_kwargs = frame_options[max_depth]
    frame_png_file_with_ticks = os.path.join(frame_dir, frame_file_name)

    # Spacing of graphical outputs; each resampled time is snapped onto it
    graphical_dt = float(resampled_df.attrs["Time between graphical output"])

    print("Start Finalizing Plots")
    for i, _time in enumerate(resampled_df["Time"].values):

        # debug run: only process step 0
        if debug_step0_animate_2d_case_basic and i > 0:
            break

        # Stepwise configurations
        time_rounded = round(_time / graphical_dt) * graphical_dt
        step = graphical_steps[i]
        print("\tFinalizing plots for step = %d, time = %.4e" % (step, time_rounded))

        for file_name in file_name_list:
            output_image_file = finalize_func(local_dir_2d, file_name, time_rounded,
                                              frame_png_file_with_ticks, add_time=False,
                                              **extra_finalize_kwargs)

#### Assemble and make animation

In the following blocks, we take the figures we produced (i.e., linear plots and the finalized figures from paraview) and:
1. Assemble them stepwise into one combined figure for each step
2. Create an AVI file from these figures

In [None]:
if animate_2d_case_basic:

    print("Start making animation")

    # Load modules
    from hamageolib.research.haoyuan_2d_subduction.workflow_scripts import create_avi_from_images

    def _require_file(path):
        """Return *path* if it is an existing file, otherwise raise FileNotFoundError.

        Used instead of `assert`, which is stripped under `python -O` and does not
        report which file is missing.
        """
        if not os.path.isfile(path):
            raise FileNotFoundError("Image file not found: %s" % path)
        return path

    def _slice_image(file_name, time_rounded):
        """Return the checked path of the finalized paraview slice for one step."""
        return _require_file(os.path.join(prep_dir, "%s_t%.4e.png" % (file_name, time_rounded)))

    # Initiation
    prep_dir = os.path.join(local_dir_2d, "img", "prep")
    ani_file_paths = [] # paths of the assembled figures, one per step
    graphical_dt = float(resampled_df.attrs["Time between graphical output"])
    include_metastable = Case_Options_2d.options["MODEL_TYPE"] == "mow"

    # Colorbars are the same for every step: resolve and check them once
    colorbar_dir = "/home/lochy/Documents/papers/documented_files/MOW/paper_dynamics"
    viscosity_colorbar = _require_file(os.path.join(colorbar_dir, "color_viscosity_19_24-01.png"))
    temperature_colorbar = _require_file(os.path.join(colorbar_dir, "color_temperature_0_2000-01.png"))
    density_colorbar = _require_file(os.path.join(colorbar_dir, "color_density_3000_4000-01.png"))
    metastable_colorbar = None
    if include_metastable:
        metastable_colorbar = _require_file(os.path.join(colorbar_dir, "color_metastable_0_1-01.png"))

    # Time stamp settings (same for every frame)
    position = (25, 0)  # text position (x, y)
    font_path = "/usr/share/fonts/truetype/msttcorefonts/times.ttf"  # Path to Times New Roman font
    font_size = 56

    # Loop the steps to get job done
    times = resampled_df["Time"].values
    for i, _time in enumerate(times):

        # debug run: only process step 0
        if debug_step0_animate_2d_case_basic and i > 0:
            break

        # Skip the last resampled step (workaround for an error at the last step)
        if i == times.size - 1:
            break

        # Stepwise configurations
        time_rounded = round(_time / graphical_dt) * graphical_dt
        step = graphical_steps[i]
        print("\tAssembling plots for step = %d, time = %.4e" % (step, time_rounded))

        # Layout of the combined figure: (image file, position on canvas, scale factor),
        # overlaid in this order. The metastable slice is the last entry of
        # file_name_list, which is only appended for "mow" models.
        layout = [
            (_slice_image(file_name_list[0], time_rounded), (0, 100), 0.9),  # viscosity slice
            (viscosity_colorbar, (100, 550), 1.5),
        ]
        if include_metastable:
            layout += [
                (_slice_image(file_name_list[-1], time_rounded), (1000, 100), 0.9),  # metastable slice
                (metastable_colorbar, (1100, 550), 1.5),
            ]
        layout += [
            (_slice_image(file_name_list[1], time_rounded), (0, 800), 0.9),  # temperature slice
            (temperature_colorbar, (100, 1250), 1.5),
            (_slice_image(file_name_list[2], time_rounded), (1000, 800), 0.9),  # density slice
            (density_colorbar, (1100, 1250), 1.5),
        ]
        image_files = [entry[0] for entry in layout]
        image_positions = [entry[1] for entry in layout]
        image_scale_factors = [entry[2] for entry in layout]
        cropping_regions = [None] * len(layout)  # no cropping for any image

        # Combine images
        output_image_file = os.path.join(prep_dir, "%s_t%.4e.png" % (animation_name, _time))
        # Remove existing output image to ensure a clean overlay
        if os.path.isfile(output_image_file):
            os.remove(output_image_file)
        plot_helper.overlay_images_on_blank_canvas(
            canvas_size=(2000, 1500),  # Size of the blank canvas in pixels (width, height)
            image_files=image_files,  # List of image file paths to overlay
            image_positions=image_positions,  # Positions of each image on the canvas
            cropping_regions=cropping_regions,  # Optional cropping regions for the images
            image_scale_factors=image_scale_factors,  # Scaling factors for resizing the images
            output_image_file=output_image_file  # Path to save the final combined image
        )

        # Add time stamp
        text = "t = %.1f Ma" % (time_rounded / 1e6)
        plot_helper.add_text_to_image(output_image_file, output_image_file, text, position, font_path, font_size)

        ani_file_paths.append(output_image_file)

    # Generate animation (skipped when only debugging step 0)
    if not debug_step0_animate_2d_case_basic:
        ani_dir = os.path.join(local_dir_2d, "img", "animation")
        os.makedirs(ani_dir, exist_ok=True)
        output_file = os.path.join(ani_dir, "%s.avi" % animation_name)
        # NOTE(review): the third argument is presumably the frame rate - confirm
        create_avi_from_images(ani_file_paths, output_file, 1)