# Gplate for slab dataset

In this notebook, I use gplately to extract the slab dataset, resample it, and plot the results.

## Prerequisite

- Install the gplately package in a conda environment. Refer to their home page and their installation instructions: [https://github.com/GPlates/gplately](https://github.com/GPlates/gplately)
- Download this package and set the path to its installation directory (the `HaMaGeoLib_DIR` variable in the first code cell)

In [None]:
# use the environment of py-gplate
import sys
import gplately
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import gridspec
import cartopy.crs as ccrs
from plate_model_manager import PlateModelManager
from shutil import rmtree

# Include this package
# change to your download path
HaMaGeoLib_DIR = "/home/lochy/ASPECT_PROJECT/HaMaGeoLib"
if os.path.abspath(HaMaGeoLib_DIR) not in sys.path:
    sys.path.append(os.path.abspath(HaMaGeoLib_DIR))

from hamageolib.research.haoyuan_3d_subduction.gplately_utilities import GPLOTTER
from hamageolib.utils.exception_handler import my_assert

# Set up

In [None]:
# Root directory of the plate reconstruction files; all outputs go below it.
case_dir = "/mnt/lochy/ASPECT_DATA/ThDSubduction/gplate_dataset"
csv_dir = os.path.join(case_dir, "csv")

# assign a reconstruction time
model_name = "Muller2019"
reconstruction_time = 40 # time of reconstruction, must be an integer
# the value is formatted with "%05d" below, so reject anything but a plain int
assert type(reconstruction_time) is int, "reconstruction_time must be an integer"
anchor_plate_id = 0 # anchor plate id: 0 - Africa

# set up a directory to output for every step
img_dir = os.path.join(case_dir, "img", "%05dMa" % reconstruction_time)

# makedirs creates missing parents (case_dir, case_dir/img) in one call and,
# with exist_ok=True, is a no-op when the directory is already there.
for _out_dir in (csv_dir, img_dir):
    os.makedirs(_out_dir, exist_ok=True)

# GplateP = GPLATE_PROCESS(case_dir)

# Import a lookup file for slab names

In [None]:
# Optional step: read the trench-name lookup table.
# The lookup files are exported from the GPlates GUI.
parse_name_lookup = False
if parse_name_lookup:
    from hamageolib.research.haoyuan_3d_subduction.gplately_utilities import read_subduction_reconstruction_data

    # the file name carries the reconstruction time with two decimals
    subduction_name_lookup_file = os.path.join(
        case_dir,
        "Muller_etal_2019_PlateBoundaries_no_topologies",
        "reconstructed_%.2fMa.xy" % float(reconstruction_time),
    )
    name_lookups = read_subduction_reconstruction_data(subduction_name_lookup_file)

## Look up by keyword

In [None]:
lookup_key_word = False

if parse_name_lookup and lookup_key_word:

    # print(name_lookups["trench_names"]) # debug

    # case-insensitive substring to search for in the trench names
    keyword = "ryu"

    # collect the positions of every trench whose name contains the keyword
    matching_indices = []
    for i, name in enumerate(name_lookups["trench_names"]):
        if keyword.lower() in name.lower():
            matching_indices.append(i)

    # report index, name and plate id of each match
    for index in matching_indices:
        print(index)
        print("name: ", name_lookups["trench_names"][index])
        print("id: ", name_lookups["trench_pids"][index])
        print("")

## Look up by trench pids

In [None]:
lookup_by_trench_pid = False

if parse_name_lookup and lookup_by_trench_pid:

    # trench plate ids whose names should be printed
    pid_lookup_list = [12001, 686, 736, 651, 669, 612, 678, 648, 659, 699, 111, 406, 413, 2000, 201, 2031, 2011, 815, 821]

    for index, pid in enumerate(name_lookups["trench_pids"]):
        # skip every trench that was not asked for
        if pid not in pid_lookup_list:
            continue
        print("pid = %d, name = %s" % (pid, name_lookups["trench_names"][index]))

# Import reconstruction dataset

In [None]:
# GplateP.reconstruct(model_name, reconstruction_time, anchor_plate_id)

# Create an instance of the PlateModelManager to manage plate models
pm_manager = PlateModelManager()

# Load the plate model selected by `model_name` in the set-up cell, instead of
# repeating the literal name here, so the two can never disagree.
plate_model = pm_manager.get_model(model_name, data_dir=os.path.join(case_dir, "plate-model-repo"))
# NOTE(review): get_model presumably returns None for an unknown model name -
# confirm against plate_model_manager; failing here gives a clearer message
# than an AttributeError a few lines further down.
if plate_model is None:
    raise ValueError("Plate model %s is not available" % model_name)

# Set up the PlateReconstruction model using the loaded plate model data
# This includes rotation models, topologies, and static polygons, with the specified anchor plate ID
# anchor_plate_id - anchor plate ID for the reconstruction model, 0 for Africa
model = gplately.PlateReconstruction(
    plate_model.get_rotation_model(),
    plate_model.get_topologies(),
    plate_model.get_static_polygons(),
    anchor_plate_id=anchor_plate_id
)

# get the reconstruction of subduction zones
subduction_data_raw = model.tessellate_subduction_zones(reconstruction_time,
                                                    tessellation_threshold_radians=0.01,
                                                    anchor_plate_id=anchor_plate_id,
                                                    ignore_warnings=True)

# Column names assigned, in order, to the raw tessellation output
all_columns = ['lon', 'lat', 'conv_rate', 'conv_angle', 'trench_velocity',
                          'trench_velocity_angle', 'arc_length', 'trench_azimuth_angle',
                          'subducting_pid', 'trench_pid']

all_columns_1 = all_columns + ['age'] # leave a placeholder for age

subduction_data = pd.DataFrame(subduction_data_raw, columns=all_columns)

## Initiate a plotter

In [None]:
# Build the project plotting helper from the loaded plate model and the
# reconstruction, then point it at the chosen reconstruction time.
gPlotter = GPLOTTER(plate_model, model)
gPlotter.set_time(reconstruction_time)

# Inspect the original dataset

In [None]:
plot_raw_dataset = False

if plot_raw_dataset:
    # global Mollweide map centred on 180 degrees longitude
    raw_map_projection = ccrs.Mollweide(central_longitude=180)
    fig = plt.figure(figsize=(10, 6), dpi=100)
    ax = fig.add_subplot(1, 1, 1, projection=raw_map_projection)
    gPlotter.plot_global_basics(ax)
    # every tessellated trench point as a small red dot (lon/lat coordinates)
    ax.scatter(
        subduction_data["lon"],
        subduction_data["lat"],
        marker=".",
        s=30,
        c="r",
        transform=ccrs.PlateCarree(),
    )

# Add age data

In [None]:
# GplateP.add_age_raster()
# GplateP.export_csv("subduction_data", "ori.csv")

# Wrap the "AgeGrids" raster of this reconstruction time in a gplately Raster
# covering the whole globe; it is used for age-related computations.
age_grid_data = plate_model.get_raster("AgeGrids", reconstruction_time)
age_grid_raster = gplately.Raster(
    data=age_grid_data,
    plate_reconstruction=model,
    extent=[-180, 180, -90, 90],
)
# Fill the NaN values before sampling. Without this, many trench points are
# not covered by the raster; those points appear to sit on the raster boundary,
# where a neighbouring value can be used, and filling does not seem to cause
# any issue when interpolating the ages.
age_grid_raster.fill_NaNs(inplace=True)

# sample the age grid at every trench point (nearest-neighbour lookup)
trench_point_ages = age_grid_raster.interpolate(subduction_data.lon, subduction_data.lat, method="nearest")
subduction_data['age'] = trench_point_ages

# export the original (not resampled) dataset
file_path = os.path.join(csv_dir, "ori.csv")
subduction_data.to_csv(file_path, index=False)
print("Saved file %s" % file_path)

# Inspect and save results of every subduction

In [None]:
# GplateP.save_results_ori(inspect_all_slabs_in_separate_plots)


# inspect_all_slabs: plot every subduction zone of the original dataset on a
# global summary map.
# inspect_all_slabs_in_separate_plots: additionally save one three-panel figure
# per subducting plate id.
inspect_all_slabs = False
inspect_all_slabs_in_separate_plots = True

if inspect_all_slabs:

    # import packages
    from hamageolib.research.haoyuan_3d_subduction.gplately_utilities import crop_region_by_data

    # start from an empty output directory so stale figures never survive
    local_img_dir = os.path.join(img_dir, "ori")
    if os.path.isdir(local_img_dir):
        rmtree(local_img_dir)
    os.mkdir(local_img_dir)

    subducting_pids = subduction_data["subducting_pid"].unique()

    # start figure
    # plot the subducting_pid in the globe
    fig0 = plt.figure(figsize=(10, 6), dpi=100)

    gPlotter.set_region("default")

    ax0 = fig0.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=gPlotter.get_central_longitude()))

    gPlotter.plot_global_basics(ax0, age_grid_raster=age_grid_raster)

    # color_dict is indexed by subducting_pid below to reuse the summary colours
    color_dict = gPlotter.plot_subduction_pts(ax0, subduction_data)
    # test adding convergence vectors
    # plot_conv=True, stepping=5)

    for i, subducting_pid in enumerate(subducting_pids):
        one_subduction_data = subduction_data[subduction_data.subducting_pid==subducting_pid]

        # add marker to summary plot: label the first point with the plate id
        ax0.text(one_subduction_data["lon"].iloc[0], one_subduction_data["lat"].iloc[0], str(subducting_pid), transform=ccrs.PlateCarree(),
            fontsize=8,
            ha="left",   # horizontal alignment
            va="bottom"  # vertical alignment
        )

        if inspect_all_slabs_in_separate_plots:
            # plot individual subduction zone in three stacked panels:
            # 1. points in the summary colour, every 10th point index labelled
            # 2. points plotted with the "trench_pid" option, no labels
            # 3. same as 2, with every trench pid value labelled
            region = crop_region_by_data(one_subduction_data, 15.0)
            gPlotter.set_region(region) # restrict the plotter to the region around this slab
            print("region: ", region) # debug

            fig = plt.figure(figsize=(10, 18), dpi=100)
            gs = gridspec.GridSpec(3, 1)

            # panel 1: point indices
            ax = fig.add_subplot(gs[0, 0], projection=ccrs.PlateCarree(central_longitude=gPlotter.get_central_longitude()))
            gPlotter.plot_global_basics(ax, age_grid_raster=age_grid_raster)
            gPlotter.plot_subduction_pts(ax, one_subduction_data, color=color_dict[subducting_pid])
            for j, (lon, lat) in enumerate(zip(one_subduction_data.lon, one_subduction_data.lat)):
                if j % 10 != 0:
                    continue
                ax.text(lon, lat, str(j), transform=ccrs.PlateCarree(),
                    fontsize=8,
                    ha="left",   # horizontal alignment
                    va="bottom"  # vertical alignment
                )
            ax.set_extent(region, crs=ccrs.PlateCarree())

            # panel 2: trench pids without labels
            ax = fig.add_subplot(gs[1, 0], projection=ccrs.PlateCarree(central_longitude=gPlotter.get_central_longitude()))
            gPlotter.plot_global_basics(ax, age_grid_raster=age_grid_raster)
            gPlotter.plot_subduction_pts(ax, one_subduction_data, "trench_pid")
            ax.set_extent(region, crs=ccrs.PlateCarree())

            # panel 3: trench pids, each labelled at its first point
            ax = fig.add_subplot(gs[2, 0], projection=ccrs.PlateCarree(central_longitude=gPlotter.get_central_longitude()))
            gPlotter.plot_global_basics(ax, age_grid_raster=age_grid_raster)
            gPlotter.plot_subduction_pts(ax, one_subduction_data, "trench_pid")
            trench_pids = one_subduction_data["trench_pid"].unique()
            ax.set_extent(region, crs=ccrs.PlateCarree())
            for j, trench_pid in enumerate(trench_pids):
                sub_subduction_data = one_subduction_data[one_subduction_data.trench_pid==trench_pid]
                lon, lat = sub_subduction_data.lon.iloc[0], sub_subduction_data.lat.iloc[0]
                ax.text(lon, lat, str(trench_pid), transform=ccrs.PlateCarree(),
                    fontsize=8,
                    ha="left",   # horizontal alignment
                    va="bottom"  # vertical alignment
                )

            # save figure of individual subduction zone
            ofile_path = os.path.join(local_img_dir, "global_subduction_ori_t%.2fMa_pid%06d" % (float(reconstruction_time), subducting_pid))
            fig.savefig(ofile_path + ".png")
            print("Saved figure %s" % (ofile_path + ".png"))
            fig.savefig(ofile_path + ".pdf")
            print("Saved figure %s" % (ofile_path + ".pdf"))
            # One figure is created per subducting plate; release it once it is
            # on disk so open figures do not pile up in memory. The summary
            # figure (fig0) stays open.
            plt.close(fig)

        # break

    # the summary map always shows the whole globe
    ax0.set_extent((-180, 180, -90, 90), crs=ccrs.PlateCarree())
    ofile_path = os.path.join(local_img_dir, "global_subduction_ori_t%.2fMa" % float(reconstruction_time))
    fig0.savefig(ofile_path + ".png")
    print("Saved figure %s" % (ofile_path + ".png"))
    fig0.savefig(ofile_path + ".pdf")
    print("Saved figure %s" % (ofile_path + ".pdf"))


# Resample the dataset

In [None]:
# GplateP.resample_subduction(arc_length_edge, arc_length_resample_section)


resample_dataset = True

if resample_dataset:

    # parameters for resampling, passed unchanged to resample_subduction
    arc_length_edge = 2.0
    arc_length_resample_section = 2.0

    # load module
    from hamageolib.research.haoyuan_3d_subduction.gplately_utilities import resample_subduction

    # get all subducting id values
    subducting_pids = subduction_data["subducting_pid"].unique()
    trench_pids = subduction_data["trench_pid"].unique()

    print("Total subduction zones: ", len(subducting_pids))
    print("Total trenches: ", len(trench_pids))
    print("subducting_pids: ", subducting_pids)

    # resample the data of one subducting plate at a time
    data_list = []
    skipped_subducting_pids = []
    for i, subducting_pid in enumerate(subducting_pids):
        mask = subduction_data.subducting_pid == subducting_pid
        one_subduction_data = subduction_data[mask]
        try:
            one_subduction_data_resampled = resample_subduction(one_subduction_data, arc_length_edge, arc_length_resample_section)
        except ValueError as err:
            # A zone that cannot be resampled is still left out, but it is now
            # reported instead of disappearing silently from the dataset.
            skipped_subducting_pids.append(subducting_pid)
            print("Skipped subducting_pid %s in resampling: %s" % (subducting_pid, err))
            continue
        data_list.append(one_subduction_data_resampled)

    # Concatenate only frames that hold rows: feeding an empty placeholder
    # frame to pd.concat relies on deprecated pandas behaviour (FutureWarning).
    non_empty_data_list = [one_data for one_data in data_list if len(one_data) > 0]
    if non_empty_data_list:
        subduction_data_resampled = pd.concat(non_empty_data_list)
    else:
        subduction_data_resampled = pd.DataFrame(columns=all_columns_1)
    # keep the layout the placeholder used to impose: the columns of
    # all_columns_1 first, followed by anything extra from the resampling
    extra_columns = [col for col in subduction_data_resampled.columns if col not in all_columns_1]
    subduction_data_resampled = subduction_data_resampled.reindex(columns=all_columns_1 + extra_columns)


    re_subducting_pids = subduction_data_resampled["subducting_pid"].unique()
    re_trench_pids = subduction_data_resampled["trench_pid"].unique()
    print("Total subduction zones (after resampling): ", len(re_subducting_pids))
    print("Total trenches (after resampling): ", len(re_trench_pids))
    print("subducting_pids (after resampling): ", re_subducting_pids)
    print("Total resampled points: ", len(subduction_data_resampled))

    # the file name records both resampling parameters
    file_path = os.path.join(csv_dir, "resampled_edge%.1f_section%.1f.csv" % (arc_length_edge, arc_length_resample_section))
    subduction_data_resampled.to_csv(file_path, index=False)
    print("Saved file %s" % file_path)


# Inspect and save results of resampled dataset

In [None]:
# GplateP.save_results_resampled(inspect_all_slabs_resampled_plot_individual)

# inspect_all_slabs_resampled: plot every subduction zone of the resampled
# dataset on a global summary map.
# inspect_all_slabs_resampled_plot_individual: additionally save one
# three-panel figure per subducting plate id.
inspect_all_slabs_resampled = False
inspect_all_slabs_resampled_plot_individual = True

if resample_dataset and inspect_all_slabs_resampled:

    from hamageolib.research.haoyuan_3d_subduction.gplately_utilities import crop_region_by_data

    # start from an empty output directory so stale figures never survive
    local_img_dir = os.path.join(img_dir, "resampled_edge%.1f_section%.1f" % (arc_length_edge, arc_length_resample_section))
    if os.path.isdir(local_img_dir):
        rmtree(local_img_dir)
    os.mkdir(local_img_dir)

    subducting_pids = subduction_data_resampled["subducting_pid"].unique()

    # start figure
    # plot the subducting_pid in the globe
    fig0 = plt.figure(figsize=(10, 6), dpi=100)

    gPlotter.set_region("default")

    ax0 = fig0.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=gPlotter.get_central_longitude()))

    gPlotter.plot_global_basics(ax0, age_grid_raster=age_grid_raster)

    # color_dict is indexed by subducting_pid below to reuse the summary colours
    color_dict = gPlotter.plot_subduction_pts(ax0, subduction_data_resampled)

    for i, subducting_pid in enumerate(subducting_pids):
        one_subduction_data = subduction_data_resampled[subduction_data_resampled.subducting_pid==subducting_pid]

        # add marker to summary plot: label the first point with the plate id
        ax0.text(one_subduction_data["lon"].iloc[0], one_subduction_data["lat"].iloc[0], str(subducting_pid), transform=ccrs.PlateCarree(),
            fontsize=8,
            ha="left",   # horizontal alignment
            va="bottom"  # vertical alignment
        )

        # plot individual subduction zone
        if inspect_all_slabs_resampled_plot_individual:
            # three stacked panels:
            # 1. points in the summary colour, every 2nd point index labelled
            # 2. points plotted with the "trench_pid" option, no labels
            # 3. same as 2, with every trench pid value labelled
            region = crop_region_by_data(one_subduction_data, 15.0)
            gPlotter.set_region(region) # restrict the plotter to the region around this slab
            print("region: ", region) # debug

            fig = plt.figure(figsize=(10, 18), dpi=100)
            gs = gridspec.GridSpec(3, 1)

            # panel 1: point indices
            ax = fig.add_subplot(gs[0, 0], projection=ccrs.PlateCarree(central_longitude=gPlotter.get_central_longitude()))
            gPlotter.plot_global_basics(ax, age_grid_raster=age_grid_raster)
            gPlotter.plot_subduction_pts(ax, one_subduction_data, color=color_dict[subducting_pid])
            for j, (lon, lat) in enumerate(zip(one_subduction_data.lon, one_subduction_data.lat)):
                if j % 2 != 0:
                    continue
                ax.text(lon, lat, str(j), transform=ccrs.PlateCarree(),
                    fontsize=8,
                    ha="left",   # horizontal alignment
                    va="bottom"  # vertical alignment
                )
            ax.set_extent(region, crs=ccrs.PlateCarree())

            # panel 2: trench pids without labels
            ax = fig.add_subplot(gs[1, 0], projection=ccrs.PlateCarree(central_longitude=gPlotter.get_central_longitude()))
            gPlotter.plot_global_basics(ax, age_grid_raster=age_grid_raster)
            gPlotter.plot_subduction_pts(ax, one_subduction_data, "trench_pid")
            ax.set_extent(region, crs=ccrs.PlateCarree())

            # panel 3: trench pids, each labelled at its first point
            ax = fig.add_subplot(gs[2, 0], projection=ccrs.PlateCarree(central_longitude=gPlotter.get_central_longitude()))
            gPlotter.plot_global_basics(ax, age_grid_raster=age_grid_raster)
            gPlotter.plot_subduction_pts(ax, one_subduction_data, "trench_pid")
            trench_pids = one_subduction_data["trench_pid"].unique()
            ax.set_extent(region, crs=ccrs.PlateCarree())
            for j, trench_pid in enumerate(trench_pids):
                sub_subduction_data = one_subduction_data[one_subduction_data.trench_pid==trench_pid]
                lon, lat = sub_subduction_data.lon.iloc[0], sub_subduction_data.lat.iloc[0]
                ax.text(lon, lat, str(trench_pid), transform=ccrs.PlateCarree(),
                    fontsize=8,
                    ha="left",   # horizontal alignment
                    va="bottom"  # vertical alignment
                )

            # save figure of individual subduction zone
            ofile_path = os.path.join(local_img_dir, "global_subduction_resampled_t%.2fMa_pid%06d" % (float(reconstruction_time), subducting_pid))
            fig.savefig(ofile_path + ".png")
            print("Saved figure %s" % (ofile_path + ".png"))
            fig.savefig(ofile_path + ".pdf")
            print("Saved figure %s" % (ofile_path + ".pdf"))
            # One figure is created per subducting plate; release it once it is
            # on disk so open figures do not pile up in memory. The summary
            # figure (fig0) stays open.
            plt.close(fig)


    # the summary map always shows the whole globe
    ax0.set_extent((-180, 180, -90, 90), crs=ccrs.PlateCarree())
    ofile_path = os.path.join(local_img_dir, "global_subduction_resampled_t%.2fMa" % float(reconstruction_time))
    fig0.savefig(ofile_path + ".png")
    print("Saved figure %s" % (ofile_path + ".png"))
    fig0.savefig(ofile_path + ".pdf")
    print("Saved figure %s" % (ofile_path + ".pdf"))


# Analysis

In [None]:
do_analysis = True

if do_analysis:

    # options
    use_resampled_data = True # use resampled dataset

    # the resampled dataset only exists when the resampling cell was enabled
    if use_resampled_data:
        assert(resample_dataset)

    # s_data is the dataset every analysis cell below works on
    s_data = subduction_data_resampled if use_resampled_data else subduction_data


## Preparation

### define additional markers

In [None]:
if do_analysis:

    from matplotlib.path import Path

    # --- snowflake marker: nine arms radiating from the centre ---------------
    arm_tips = [
        (0.2, 0.6),    # Upper arm
        (0.4, 0.4),    # Right diagonal
        (0.6, 0.2),    # Right arm
        (0.4, -0.4),   # Right down diagonal
        (0.2, -0.6),   # Bottom arm
        (-0.4, -0.4),  # Left down diagonal
        (-0.6, -0.2),  # Left arm
        (-0.4, 0.4),   # Left diagonal
        (-0.2, 0.6),   # Upper left arm
    ]
    # every arm is a (centre, tip) vertex pair
    verts = []
    for arm_tip in arm_tips:
        verts += [(0., 0.), arm_tip]
    # One MOVETO (jump to the centre) + one LINETO (draw to the tip) per arm.
    # The previous code list ended with MOVETO for the last vertex, so the
    # ninth arm ("Upper left arm") was never drawn.
    codes = [Path.MOVETO, Path.LINETO] * len(arm_tips)
    snowflake = Path(verts, codes)

    # --- star marker: two overlapping equilateral triangles ------------------
    vertices = [
        [0, 1], [-np.sqrt(3)/2, -0.5], [np.sqrt(3)/2, -0.5], [0, 1],  # First triangle
        [0, -1], [-np.sqrt(3)/2, 0.5], [np.sqrt(3)/2, 0.5], [0, -1]   # Second triangle
    ]
    # Path accepts an (N, 2) array of vertices
    vertices = np.array(vertices)
    # Each triangle is its own sub-path (MOVETO followed by three LINETO).
    # Starting the second triangle with LINETO, as before, drew a stray segment
    # from (0, 1) to (0, -1) through the middle of the marker.
    codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO] * 2
    star_path = Path(vertices, codes)

### plot options for cases

Notes:

- The plot_by_name option plots the data points with their assigned names. These names have to be specified. By default, it is set to False and the subducting_pid is used to plot.

In [None]:
if do_analysis:

    # Plot options
    plot_by_name = False  # True - by name; False - by pid

    # Retrieve the default color cycle
    default_colors = [color['color'] for color in plt.rcParams['axes.prop_cycle']]

    # assign plot options
    # Each entry is (selector, style). The selector is either a subducting pid
    # or a {subducting_pid: trench pid(s)} dict; the style dict holds the
    # marker, its face colour and the name used for this group.
    if plot_by_name:
        # named groups are only defined for the present day (0 Ma)
        if reconstruction_time == 0:
            plot_options = \
            (
                (903, {"marker": 'o',  "markerfacecolor": "yellow", "name": "CAS"}),
                (511, {"marker": 's',  "markerfacecolor": "yellow", "name": "ANDA-SUM"}),
                (801, {"marker": 'd',  "markerfacecolor": "yellow", "name": "JAVA"}),
                (645, {"marker": snowflake,  "markerfacecolor": "black", "name": "SULA"}),
                (602, {"marker": 'x',  "markerfacecolor": "blue", "name": "LUZ"}),
                (608, {"marker": 's',  "markerfacecolor": 'c', "name": "PHIL"}),
                ({901: 699}, {"marker": '>',  "markerfacecolor": 'red', "name": "MAR"}),
                ({901: 659}, {"marker": 's',  "markerfacecolor": 'red', "name": "IZU"}),
                ({901: (601115.0, 601118.0)}, {"marker": '^',  "markerfacecolor": 'green', "name": "JAP"}),
                ({901: 406}, {"marker": 'v',  "markerfacecolor": 'green', "name": "KUKAM"}),
                ({901: 111}, {"marker": 'o',  "markerfacecolor": 'pink', "name": "ALE-ALA"}),
                ({901: (806, 821)}, {"marker": 'd',  "markerfacecolor": 'blue', "name": "TON-KERM"}),
                (909, {"marker": star_path,  "markerfacecolor": 'c', "name": "MEX"}),
                (911, {"marker": 'o',  "markerfacecolor": 'k', "name": "PER-NCHI-JUAN-SCHI"}),
                (802, {"marker": 'd',  "markerfacecolor": 'k', "name": "SSCHI-TBD"}),
                ({201: 2011}, {"marker": '+',  "markerfacecolor": 'pink', "name": "ANT"}),
                ({201: 815}, {"marker": '*',  "markerfacecolor": 'r', "name": "SAND"}),
                (1, {"marker": 'd',  "markerfacecolor": "r", "name": "RYU"})
            )
        else:
            raise NotImplementedError()
    else:
        # One entry per subducting pid: cycle through the colours first and
        # move on to the next marker once every colour has been used.
        markers = ["o", '*', "d", "x", "v", "s"]
        # follow the real palette length instead of assuming ten colours
        n_color = len(default_colors)
        plot_options = []
        subducting_pids = s_data.subducting_pid.unique()
        for i, subducting_pid in enumerate(subducting_pids):
            # wrap around the marker list so a long pid list cannot raise an
            # IndexError (the unwrapped index ran out after 6 * n_color pids)
            pid_marker = markers[(i // n_color) % len(markers)]
            pid_color = default_colors[i % n_color]
            plot_options.append((int(subducting_pid), {"marker": pid_marker,  "markerfacecolor": pid_color, "name": str(int(subducting_pid))}))

## Generate vs age plots

In [None]:
# switch for this cell; it only runs when do_analysis is set as well
analyze_age_combined = True

# GplateP.plot_age_combined(plot_options, resample_dataset)

# honour the cell's own switch (it used to be defined but never checked)
if do_analysis and analyze_age_combined:

    from matplotlib import rcdefaults
    from matplotlib.ticker import MultipleLocator
    from hamageolib.research.haoyuan_3d_subduction.gplately_utilities import parse_subducting_trench_option, plot_age_combined
    import hamageolib.utils.plot_helper as plot_helper

    # Name the output directory after the dataset that is actually plotted:
    # s_data is chosen by use_resampled_data, so the directory follows that
    # flag rather than resample_dataset.
    if use_resampled_data:
        local_img_dir = os.path.join(img_dir, "age_combined_resampled_edge%.1f_section%.1f" %\
                                     (arc_length_edge, arc_length_resample_section))
    else:
        local_img_dir = os.path.join(img_dir, "age_combined_t%.2fMa" % float(reconstruction_time))

    # start from an empty output directory so stale figures never survive
    if os.path.isdir(local_img_dir):
        rmtree(local_img_dir)
    os.mkdir(local_img_dir)

    # one figure per entry of plot_options
    for sub_plot_options in plot_options:
        print("sub_plot_options: ", sub_plot_options)
        _name = sub_plot_options[1]["name"]
        fig, axes = plot_age_combined(s_data, [sub_plot_options], plot_index=True)
        file_path = os.path.join(local_img_dir, "age_combined_%s" % _name)
        fig.savefig(file_path + ".png")
        print("Saved figure %s" % (file_path + ".png"))
        fig.savefig(file_path + ".pdf")
        print("Saved figure %s" % (file_path + ".pdf"))
        # release the per-entry figure once saved; only the combined figure
        # below stays open
        plt.close(fig)


    # one figure combining every entry of plot_options
    fig, axes = plot_age_combined(s_data, plot_options)
    file_path = os.path.join(local_img_dir, "age_combined")
    fig.savefig(file_path + ".png")
    print("Saved figure %s" % (file_path + ".png"))
    fig.savefig(file_path + ".pdf")
    print("Saved figure %s" % (file_path + ".pdf"))