# Extract training data

This notebook extracts plate kinematic data from a plate model, together with other data from the `source_data` directory, and writes the resulting dataset to a CSV file. That file is then used to train the models in the following notebooks (`01*.ipynb`).

## Notebook options

These cells set some of the important variables and definitions used throughout the notebook.

### Select plate model

To use the plate model from the published paper (Alfonso et al., 2024), set `use_provided_plate_model` to `True`. Otherwise, leave `use_provided_plate_model` as `False` and set `plate_model_name` to a valid model name for the [`plate-model-manager`](https://github.com/michaelchin/plate-model-manager/blob/4f66423b53950bf42f5dac1228e61fd1e19fdf6e/models.json) package, or set `plate_model_name` to `None` and place GPlates files in a directory named `plate_model`.

| `use_provided_plate_model` | `plate_model_name` | result |
| - | - | - |
| `True` | Any | Use Alfonso et al., 2024 model |
| `False` | Model name string (e.g. `"muller2022"`) | Use specified plate model |
| `False` | `None` | Use files in `plate_model` directory |
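
As a quick check of the table above, the selection logic can be sketched as a small function. The name `resolve_plate_model` and its return strings are hypothetical and are not part of this repository; the sketch only restates the three rows of the table.

```python
def resolve_plate_model(use_provided_plate_model, plate_model_name):
    """Describe which plate model the two options select.

    Hypothetical helper for illustration only; it mirrors the table above.
    """
    if use_provided_plate_model:
        # plate_model_name is ignored in this case
        return "provided"
    if plate_model_name is not None:
        return f"named:{plate_model_name}"
    return "local_directory"


print(resolve_plate_model(True, "muller2022"))   # provided
print(resolve_plate_model(False, "muller2022"))  # named:muller2022
print(resolve_plate_model(False, None))          # local_directory
```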

In [1]:
use_provided_plate_model = True
plate_model_name = "muller2022"

### Set other parameters

The options most likely to need changing are `n_jobs` (the number of parallel processes) and `max_time` (the oldest time, in Ma, included in the analysis).

In [2]:
# Allow reproducibility of randomised results
random_seed = 1234

# Number of processes to use
n_jobs = 4

# Overwrite any existing output files
overwrite = False

# Timespan for analysis
min_time = 0
max_time = 170

# Control verbosity level of logging output
verbose = False

# Number of unlabelled points to generate
num_unlabelled = 200  # per timestep

If any of the following exist as environment variables, they will replace the values defined above. Boolean options (`OVERWRITE` and `VERBOSE`) must be given as integers, e.g. `0` or `1`.

In [3]:
import os

# Override above values with environment variables, if they exist
n_jobs = int(os.environ.get("N_JOBS", n_jobs))
overwrite = bool(int(os.environ.get("OVERWRITE", overwrite)))
min_time = int(os.environ.get("MIN_TIME", min_time))
max_time = int(os.environ.get("MAX_TIME", max_time))
verbose = bool(int(os.environ.get("VERBOSE", verbose)))
num_unlabelled = int(os.environ.get("NUM_UNLABELLED", num_unlabelled))

times = range(min_time, max_time + 1)
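
The cell above parses boolean options with `bool(int(...))`, so a value such as `OVERWRITE=true` would raise a `ValueError`. If more tolerant parsing were wanted, one possible approach is sketched below. The helper `env_flag`, the set of accepted spellings, and the variable name `EXAMPLE_FLAG` are assumptions made for illustration and are not used elsewhere in this notebook.

```python
import os


def env_flag(name, default):
    """Read a boolean from the environment, falling back to `default`.

    Hypothetical helper; accepts a few common spellings in addition
    to the "0"/"1" values expected by the cell above.
    """
    raw = os.environ.get(name)
    if raw is None:
        return bool(default)
    value = raw.strip().lower()
    if value in {"1", "true", "yes", "on"}:
        return True
    if value in {"0", "false", "no", "off", ""}:
        return False
    raise ValueError(f"Cannot interpret {name}={raw!r} as a boolean")


# EXAMPLE_FLAG is a made-up variable name used only for this demonstration
os.environ["EXAMPLE_FLAG"] = "true"
print(env_flag("EXAMPLE_FLAG", False))  # True
del os.environ["EXAMPLE_FLAG"]
print(env_flag("EXAMPLE_FLAG", False))  # False
```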

## Notebook setup

Imports, definitions, etc.

### Imports

In [4]:
import warnings

import geopandas as gpd
import pandas as pd
from gplately.tools import plate_isotherm_depth

from lib.assign_regions import assign_regions
from lib.calculate_convergence import run_calculate_convergence
from lib.check_files import (
    check_source_data,
    check_plate_model,
)
from lib.combine_point_data import combine_point_data
from lib.coregister_combined_point_data import run_coregister_combined_point_data
from lib.coregister_crustal_thickness import run_coregister_crustal_thickness
from lib.coregister_ocean_rasters import (
    extract_subducted_thickness,
    run_coregister_ocean_rasters,
)
from lib.create_study_area_polygons import run_create_study_area_polygons
from lib.erodep import calculate_erodep
from lib.generate_unlabelled_points import generate_unlabelled_points
from lib.misc import calculate_slab_flux, calculate_carbon
from lib.plate_models import get_plate_reconstruction
from lib.slab_dip import calculate_slab_dip
from lib.water import calculate_water_thickness

# Suppress occasional joblib warnings
%env PYTHONWARNINGS=ignore::UserWarning
warnings.simplefilter("ignore", UserWarning)



### Input and output files

If necessary, the plate model will be downloaded:

In [5]:
plate_model_dir = "plate_model"
if use_provided_plate_model:
    check_plate_model(plate_model_dir, verbose=True)
    plate_model_name = None
plate_model = get_plate_reconstruction(
    model_name=plate_model_name,
    model_dir=plate_model_dir,
)

The directory containing the datasets to be extracted:

In [6]:
data_dir = "source_data"
data_dir = check_source_data(data_dir, verbose=verbose)

Output files will be created in this directory:

In [7]:
output_dir = "extracted_data"
os.makedirs(output_dir, exist_ok=True)

The following input files and directories are all relative to `data_dir`:

In [8]:
# CSV file with known deposits; columns:
# lon, lat, age (Ma)
deposits_filename = "deposit_data_global.csv"

# If desired, categorise deposits according to location
# Should be a shapefile or GeoJSON containing polygons
# with a 'region' attribute
regions_filename = "regions.geojson"

# Seafloor age grid directory
# Filename format 'seafloor_age_{time}Ma.nc'
agegrid_dir = "AgeGrids"

# Seafloor sediment thickness directory
# Filename format 'sediment_thickness_{time}Ma.nc'
sedthick_dir = "SedimentThickness"

# Seafloor carbonate sediment thickness directory
# Filename format 'carbonate_thickness_{time}Ma.nc'
carbonate_dir = "CarbonateThickness"

# Oceanic crustal CO2 density directory
# Filename format 'crustal_co2_{time}Ma.nc'
co2_dir = "CrustalCO2"

# Overriding plate thickness directory
# Filename format 'crustal_thickness_{time}Ma.nc'
crustal_thickness_dir = "CrustalThickness"

# Cumulative subducted sediments/carbonates/etc. directory
# Filename format 'sediment_thickness/cumulative_density_{time}Ma.nc',
# 'carbonate_thickness/cumulative_density_{time}Ma.nc', etc.
subducted_quantities_dir = "SubductedQuantities"

# Erosion/deposition rate directory
# Filename format 'erosion_deposition_{time}Ma.nc'
erodep_dir = "ErosionDeposition"

In [9]:
# Handle relative file/directory paths

deposits_filename = os.path.join(data_dir, deposits_filename)
regions_filename = os.path.join(data_dir, regions_filename)
agegrid_dir = os.path.join(data_dir, agegrid_dir)
sedthick_dir = os.path.join(data_dir, sedthick_dir)
carbonate_dir = os.path.join(data_dir, carbonate_dir)
co2_dir = os.path.join(data_dir, co2_dir)
crustal_thickness_dir = os.path.join(data_dir, crustal_thickness_dir)
subducted_quantities_dir = os.path.join(data_dir, subducted_quantities_dir)
erodep_dir = os.path.join(data_dir, erodep_dir)

subduction_data_filename = os.path.join(output_dir, "subducting_plate_data.csv")
study_area_dir = os.path.join(output_dir, "study_area_polygons")
output_filename = os.path.join(output_dir, "training_data_global.csv")

### Subducting plate data

This cell extracts the subduction kinematics data from the plate model, along with datasets relating to the subducting oceanic plate: seafloor age, sediment and carbonate thickness, etc.
However, if these data have already been extracted (for example, by another notebook) and `overwrite` has not been set to `True`, they will be read from the existing file (`subducting_plate_data.csv` in the output directory) instead.

In [10]:
if (
    subduction_data_filename is not None and os.path.isfile(subduction_data_filename)
) and (not overwrite):
    subduction_data = pd.read_csv(subduction_data_filename)
else:
    subduction_data = run_calculate_convergence(
        nprocs=n_jobs,
        min_time=min(times),
        max_time=max(times),
        plate_reconstruction=plate_model,
        verbose=verbose,
    )

    subduction_data = run_coregister_ocean_rasters(
        nprocs=n_jobs,
        times=times,
        input_data=subduction_data,
        agegrid_dir=agegrid_dir,
        plate_reconstruction=plate_model,
        sedthick_dir=sedthick_dir,
        carbonate_dir=carbonate_dir,
        co2_dir=co2_dir,
        subducted_thickness_dir=os.path.join(
            subducted_quantities_dir,
            "plate_thickness",
        ),
        subducted_sediments_dir=os.path.join(
            subducted_quantities_dir,
            "sediment_thickness",
        ),
        subducted_carbonates_dir=os.path.join(
            subducted_quantities_dir,
            "carbonate_thickness",
        ),
        subducted_water_dir=os.path.join(
            subducted_quantities_dir,
            "water_thickness",
        ),
        verbose=verbose,
    )
    subduction_data["plate_thickness (m)"] = plate_isotherm_depth(
        subduction_data["seafloor_age (Ma)"],
        maxiter=100,
    )
    subduction_data = calculate_water_thickness(data=subduction_data)
    subduction_data = calculate_carbon(subduction_data)
    subduction_data = calculate_slab_flux(subduction_data)
    subduction_data = calculate_slab_dip(subduction_data)
    subduction_data = extract_subducted_thickness(
        subduction_data,
        plate_reconstruction=plate_model,
    )

    if subduction_data_filename is not None:
        subduction_data.to_csv(subduction_data_filename, index=False)
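
The read-or-recompute pattern used in the cell above can be shown in isolation. The sketch below is a generic illustration, not code from `lib`: the helper `load_or_compute` is hypothetical, and it uses JSON rather than CSV so that it runs with the standard library alone.

```python
import json
import os
import tempfile


def load_or_compute(filename, compute, overwrite=False):
    """Return cached results from `filename`, or call `compute()` and cache them."""
    if filename is not None and os.path.isfile(filename) and not overwrite:
        with open(filename) as f:
            return json.load(f)
    result = compute()
    if filename is not None:
        with open(filename, "w") as f:
            json.dump(result, f)
    return result


calls = []


def expensive():
    calls.append(1)  # record each real computation
    return {"value": 42}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cache.json")
    first = load_or_compute(path, expensive)  # computes and writes the file
    second = load_or_compute(path, expensive)  # reads from the file
    third = load_or_compute(path, expensive, overwrite=True)  # recomputes

print(len(calls))  # 2
```

Only the first and third calls run the expensive function, which is why setting `overwrite = True` is needed after changing the plate model or the input grids.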

### Create study area polygons along subduction zones

Here we define our study area as all points on the overriding plate within a certain distance of the subduction zone (by default, $6^\circ \approx 660\,\mathrm{km}$).
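
For reference, the conversion from an angular buffer distance to kilometres is simple arc-length arithmetic. The sketch below assumes a spherical Earth with a mean radius of 6371 km (an assumed value, not taken from this repository) and gives about 667 km for 6 degrees, in line with the approximate figure quoted above.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius (assumed value)


def arc_degrees_to_km(degrees, radius_km=EARTH_RADIUS_KM):
    """Length of a great-circle arc subtending `degrees` at the Earth's centre."""
    return math.radians(degrees) * radius_km


print(round(arc_degrees_to_km(6.0)))  # 667
```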

In [11]:
from lib.create_study_area_polygons import DEFAULT_SZ_BUFFER_DISTANCE

buffer_distance = DEFAULT_SZ_BUFFER_DISTANCE  # 6.0

if overwrite or not os.path.isdir(study_area_dir):
    run_create_study_area_polygons(
        nprocs=n_jobs,
        times=times,
        plate_reconstruction=plate_model,
        output_dir=study_area_dir,
        buffer_distance=buffer_distance,
        verbose=verbose,
        return_output=False,
    )

### Generate random unlabelled data points

The unlabelled set is created by generating uniformly-distributed random points within the polygons created in the previous cell. To change the number of points generated at each timestep, modify the `num_unlabelled` parameter defined earlier.

In [12]:
unlabelled = generate_unlabelled_points(
    times=times,
    input_dir=study_area_dir,
    num=num_unlabelled,
    threads=n_jobs,
    seed=random_seed,
    plate_reconstruction=plate_model,
    verbose=verbose,
)
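
To illustrate what "uniformly distributed" means on a sphere, the sketch below draws points by rejection sampling. It is not the implementation of `generate_unlabelled_points`, which may work differently; the function `sample_uniform_on_sphere` and the box-shaped study area are hypothetical. The key step is sampling the sine of the latitude uniformly, since sampling latitude itself would over-represent high latitudes.

```python
import math
import random


def sample_uniform_on_sphere(n, inside, lon_range=(-180.0, 180.0),
                             lat_range=(-90.0, 90.0), seed=None):
    """Draw `n` (lon, lat) points, uniform by area, for which `inside` is true."""
    rng = random.Random(seed)
    sin_min = math.sin(math.radians(lat_range[0]))
    sin_max = math.sin(math.radians(lat_range[1]))
    points = []
    while len(points) < n:
        lon = rng.uniform(*lon_range)
        # Sampling sin(lat) uniformly gives equal probability per unit area
        lat = math.degrees(math.asin(rng.uniform(sin_min, sin_max)))
        if inside(lon, lat):
            points.append((lon, lat))
    return points


def in_example_area(lon, lat):
    """Hypothetical study area: a simple lon/lat box."""
    return -80.0 <= lon <= -60.0 and -40.0 <= lat <= 0.0


points = sample_uniform_on_sphere(
    200,
    in_example_area,
    lon_range=(-90.0, -50.0),
    lat_range=(-50.0, 10.0),
    seed=1234,
)
print(len(points))  # 200
```

Passing a fixed `seed` makes the draw repeatable, which is the role `random_seed` plays in the cell above.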

### Combine labelled deposit/non-deposit data with random unlabelled data

The function below converts the points generated in the previous cell into the same format as the deposit location data and combines the two sets. Points without present-day coordinates are then dropped.

In [13]:
combined_points = combine_point_data(
    deposit_data=deposits_filename,
    unlabelled_data=unlabelled,
    plate_reconstruction=plate_model,
    study_area_dir=study_area_dir,
    min_time=min(times),
    max_time=max(times),
    n_jobs=n_jobs,
    verbose=verbose,
)
del unlabelled
combined_points = combined_points.dropna(subset=["present_lon", "present_lat"])

### Assign subduction data to point deposit/non-deposit/unlabelled data

Here we assign the appropriate values for the subduction-related parameters (kinematics, seafloor age, etc.) to the deposit sites and random locations.

In [14]:
coregistered_data = run_coregister_combined_point_data(
    point_data=combined_points,
    subduction_data=subduction_data,
    n_jobs=n_jobs,
    verbose=verbose,
)
del combined_points, subduction_data
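
One simple way to picture this step is a nearest-neighbour lookup: for each point, find the closest trench sample at the same reconstruction time and copy its values. The sketch below assumes exactly that and may differ from what `run_coregister_combined_point_data` actually does. The dictionary keys, the `convergence_rate` field, the sample values, and the 6371 km Earth radius are all made up for illustration.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius (assumed value)


def haversine_km(lon1, lat1, lon2, lat2):
    """Great-circle distance between two lon/lat points, in kilometres."""
    lon1, lat1, lon2, lat2 = map(math.radians, (lon1, lat1, lon2, lat2))
    a = (
        math.sin((lat2 - lat1) / 2) ** 2
        + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2
    )
    # Clamp to guard against rounding slightly above 1
    return 2 * EARTH_RADIUS_KM * math.asin(min(1.0, math.sqrt(a)))


def nearest_trench_sample(point, trench_samples):
    """Return the trench sample closest to `point` at the same reconstruction time."""
    candidates = [s for s in trench_samples if s["age"] == point["age"]]
    return min(
        candidates,
        key=lambda s: haversine_km(point["lon"], point["lat"], s["lon"], s["lat"]),
    )


# Made-up values for illustration only
trench_samples = [
    {"age": 10, "lon": -75.0, "lat": -20.0, "convergence_rate": 7.1},
    {"age": 10, "lon": -72.0, "lat": -35.0, "convergence_rate": 6.4},
    {"age": 20, "lon": -74.0, "lat": -21.0, "convergence_rate": 9.0},
]
point = {"age": 10, "lon": -70.0, "lat": -22.0}
print(nearest_trench_sample(point, trench_samples)["convergence_rate"])  # 7.1
```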

### Assign crustal thickness data to point data

This cell extracts the overriding plate thickness at each point.

In [15]:
coregistered_data = run_coregister_crustal_thickness(
    point_data=coregistered_data,
    input_dir=crustal_thickness_dir,
    n_jobs=n_jobs,
    verbose=verbose,
)

### Calculate cumulative erosion

Here we calculate the cumulative erosion experienced by each deposit/random point since its time of formation.

In [16]:
coregistered_data = calculate_erodep(
    coregistered_data,
    input_dir=erodep_dir,
    n_jobs=n_jobs,
    column_name="erosion (m)",
    verbose=verbose,
)
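
Conceptually, cumulative erosion is the erosion rate summed over every timestep between a point's formation and the present day. The sketch below shows only that accumulation and does not reproduce `calculate_erodep`. The function names, the units (m/Myr), the sign convention (positive values mean erosion), and the example rate history are assumptions.

```python
def cumulative_erosion(age, rate_at, timestep=1):
    """Sum erosion from the time of formation (`age`) to the present day.

    `rate_at(t)` is assumed to return the erosion rate (m/Myr) at the
    point's location at time `t` (Ma); `timestep` is in Myr.
    """
    total = 0.0
    # Each t stands for the interval from t to t + timestep (Ma)
    for t in range(0, int(age), timestep):
        total += rate_at(t) * timestep
    return total


def example_rate(t):
    """Hypothetical history: 10 m/Myr before 5 Ma, 20 m/Myr since."""
    return 20.0 if t < 5 else 10.0


print(cumulative_erosion(8, example_rate))  # 130.0
```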

### Assign data to regions

If the file given by `regions_filename` (defined earlier) exists, each point is assigned to a region for use in the later analysis; otherwise this step is skipped.

In [17]:
if regions_filename is not None and os.path.isfile(regions_filename):
    points = gpd.GeoSeries.from_xy(
        coregistered_data["present_lon"],
        coregistered_data["present_lat"],
        index=coregistered_data.index,
    )
    coregistered_data["region"] = assign_regions(
        points,
        regions=regions_filename,
    )
    del points

### Save to file

Finally, we write the dataset to a CSV file.

In [18]:
coregistered_data.to_csv(output_filename, index=False)

coregistered_data.groupby(["region", "label"]).size()

region          label     
East Asia       negative         7
                positive         5
                unlabelled    4734
North America   negative        45
                positive       257
                unlabelled    7890
Other           negative       203
                positive         1
                unlabelled    3741
South America   negative      1096
                positive       211
                unlabelled    5709
Southeast Asia  negative         4
                positive        55
                unlabelled    7811
Tethys          negative        20
                positive        61
                unlabelled    5922
dtype: int64