# Working with Landlab and PyGMT

The PyGMT package, a Python-wrapped version of Generic Mapping Tools, provides powerful capabilities for plotting and visualizing geoscientific data. PyGMT includes the capability of pulling in digital elevation model (DEM) data from a remote server. This notebook demonstrates several aspects of working with Landlab and PyGMT together:

- Downloading DEM data using PyGMT.
- Projecting from geographic to UTM coordinates.
- Converting a PyGMT elevation grid into a Landlab `RasterModelGrid`.
- Running a Landlab component using the converted GMT elevation grid.
- Converting derived data back to GMT grid format.
- Projecting back to geographic.
- Visualizing the output with PyGMT plotting functions.


## Requirements:

- Landlab
- PyGMT
- NumPy
- xarray
- Matplotlib

*(Tutorial written by Greg Tucker, August 2024)*

## Importing a DEM using PyGMT

We start off using the PyGMT function `load_earth_relief` to get a DEM. Here we'll use the San Juan Islands, northwestern US, as an example.

Start with some imports:

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pygmt
import xarray as xr

from landlab import RasterModelGrid

Next, run PyGMT's `load_earth_relief` function. We give it the bounding box of our region as `[west, east, south, north]`, in degrees of longitude and latitude. We will request a resolution of three arc seconds (about 90 meters in the north-south direction) per grid cell. The function returns an `xarray.DataArray` with the elevation values and some metadata.

In [None]:
san_juan_islands = [-123.24, -122.76, 48.41, 48.73]

gmtgrid = pygmt.datasets.load_earth_relief(
    resolution="03s",
    region=san_juan_islands,
)
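The "about 90 meters" figure holds in the north-south direction only. As a quick check, here is a sketch that converts a grid spacing in arc seconds to ground distance, assuming a spherical Earth with a mean radius of 6,371 km (an approximation; GMT itself works with an ellipsoid). The helper name `arcsec_spacing_m` is hypothetical. At the latitude of the San Juan Islands it gives roughly 93 m north-south but only about 61 m east-west, so the geographic grid cells are far from square when measured in meters.

```python
import math

# Assumption: spherical Earth with mean radius 6,371 km (not GMT's ellipsoid)
EARTH_RADIUS_M = 6_371_000.0


def arcsec_spacing_m(arcsec, lat_deg):
    """Return (north-south, east-west) ground spacing in meters for a
    grid spacing of `arcsec` arc seconds at latitude `lat_deg`."""
    angle_rad = math.radians(arcsec / 3600.0)
    ns = EARTH_RADIUS_M * angle_rad  # meridian arc length
    ew = ns * math.cos(math.radians(lat_deg))  # parallels shrink with latitude
    return ns, ew


# Center latitude of the San Juan Islands region requested above
ns, ew = arcsec_spacing_m(3.0, 0.5 * (48.41 + 48.73))
print(f"N-S spacing: {ns:.1f} m, E-W spacing: {ew:.1f} m")
```

This mismatch between the two directions is one more reason to project the DEM to a distance-based coordinate system before handing it to Landlab.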

In [None]:
# Get info about gmtgrid
gmtgrid

The information above tells us that `gmtgrid` is an xarray `DataArray` containing a 2D array of values called `z`, with the coordinates of its rows and columns stored in `lat` and `lon`, along with a variety of *attributes* (metadata such as units).

Let's display the GMT elevation grid, using PyGMT's `grdimage` and `grdcontour` functions:

In [None]:
fig = pygmt.Figure()
fig.grdimage(
    grid=gmtgrid,
    frame=["a", "+tSan Juan Islands"],
    projection="M10c",  # M = Mercator; 10c = 10 cm wide
    cmap="oleron",
)
fig.grdcontour(grid=gmtgrid, levels=25.0, limit=[0, 750.0], annotation=None)
fig.colorbar(frame=["a100", "x+lElevation", "y+lm"])
fig.show()

Here's an alternative example using `grdview`, which draws a 3D perspective view of the surface:

In [None]:
fig = pygmt.Figure()
fig.grdview(
    grid=gmtgrid,
    perspective=[160, 30],
    frame=["xa", "yaf", "WSnE"],
    projection="M15c",
    zsize="1.5c",
    # Set the surftype to "surface"
    surftype="s",
    # Set the CPT (color palette) to "geo"
    cmap="geo",
)
fig.show()

### Projecting to UTM

Before we transfer this DEM into the format of a Landlab `RasterModelGrid`, we need to address a drawback of having a DEM in geographic (lat-lon) coordinates. Landlab components that operate on terrain often expect horizontal and vertical distances to be in the same units; that is, the $x$ and $y$ coordinates of the grid elements are assumed to be in the same units as the field `topographic__elevation`. So far, however, our DEM uses degrees for horizontal coordinates and meters for elevation values.

To get around this problem, we can project the DEM onto a coordinate system with distance-based rather than geographic coordinates in the horizontal. PyGMT provides a function to do this: `grdproject`. It uses interpolation to create a new grid whose points are georeferenced in a given projected coordinate system. Here we'll project the grid data into the widely used Universal Transverse Mercator (UTM) coordinate system. Note that because grid projection involves interpolation, some information will be smeared out in the process.
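To see why mixing units matters, consider a simple gradient calculation on the unprojected grid. The sketch below uses made-up numbers (a hypothetical 10 m rise across one 3-arc-second cell) and a cell width of about 92.66 m, the spherical-Earth north-south value. Dividing by the spacing in degrees yields a "slope" roughly five orders of magnitude too large.

```python
import math

dz = 10.0  # hypothetical elevation change across one cell, meters
dx_deg = 3.0 / 3600.0  # 3-arc-second cell width, degrees
dx_m = 92.66  # same width in meters (N-S direction, spherical Earth)

slope_mixed = dz / dx_deg  # meters per degree: NOT a real gradient
slope_true = dz / dx_m  # meters per meter: a dimensionless gradient

print(f"mixed-unit 'slope': {slope_mixed:.0f}")
print(f"true slope: {slope_true:.3f} ({math.degrees(math.atan(slope_true)):.1f} degrees)")
```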

Below we define a function to project a GMT grid from geographic to UTM coordinates. We also define a small helper function to infer the UTM zone from the center longitude and latitude:

In [None]:
def utm_zone_from_lon(center_lon, center_lat):
    """Infer UTM zone from center lon and lat.

    Warning: this isn't perfect. For the most part the UTM
    system is divided into six-degree-wide N-S strips, but
    there are some exceptions that this method ignores.
    """
    zone = str(int(1 + (180 + center_lon) // 6.0))
    if center_lat >= 0.0:
        zone += "N"
    else:
        zone += "S"
    return zone


def gmt_geog_to_utm(gmt_grid):
    """
    Project a GMT lat-lon grid to a UTM grid at the same
    resolution.
    """
    # Bounding box of the input grid, in degrees
    west = float(np.amin(gmt_grid.lon))
    east = float(np.amax(gmt_grid.lon))
    south = float(np.amin(gmt_grid.lat))
    north = float(np.amax(gmt_grid.lat))
    zone = utm_zone_from_lon(0.5 * (west + east), 0.5 * (south + north))
    out_grid = pygmt.grdproject(
        grid=gmt_grid,
        projection="U" + zone + "/12c",  # UTM zone plus a 12 cm map width
        scaling="e",  # 1:1 scaling, so output coordinates are in meters
        region=[west, east, south, north],
        center=[1.0e-6, 1.0e-6],  # some nonzero val needed to preserve offset...?
    )
    out_grid.attrs["units"] = gmt_grid.units  # copy units of the z values
    return out_grid

The `gmt_geog_to_utm` function takes a GMT grid as an input. The GMT grid is assumed to be in geographic coordinates. To project it, we call the PyGMT `grdproject` function, giving it the following arguments:

- The GMT grid (an `xarray.DataArray`)
- The desired new projection: `U` for UTM, plus the zone code and a map width (here 12 cm)
- Scaling option `e`, which forces 1:1 scaling so that the output coordinates are in actual projected meters rather than map units (by default, `grdproject` creates as many output nodes as there are input nodes, so the resolution stays roughly the same)
- The region (the author does not know why the function doesn't just use the original region as a default, but apparently it needs to be specified, so we pass in the region of the input grid)
- The `center` keyword (GMT's `-C` option), which makes the projected coordinates relative to the projection center rather than to the lower-left corner of the grid, so that the grid keeps its offset within the UTM coordinate system (apparently if we omit this or send (0, 0), the offset is lost, so here we give it a tiny positive value)

We also copy the `units` attribute from the input grid to the output grid, recording the fact that the elevation values are in meters. Here we'll create a UTM-projected grid of our elevation data from the San Juan Islands:

In [None]:
proj_gmt_grid = gmt_geog_to_utm(gmtgrid)
proj_gmt_grid

Note that this `DataArray` has `x` and `y` in place of longitude and latitude, and the horizontal coordinates are in meters.
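The zone code matters because each UTM zone is its own transverse Mercator projection, centered on its own meridian. As a small worked check, here is a standalone sketch using the same six-degree rule as `utm_zone_from_lon` above; the helper names `utm_zone_number` and `central_meridian` are hypothetical and not part of PyGMT or Landlab. The San Juan Islands region is centered at longitude -123.0, which falls in zone 10, and the central meridian of zone 10 is exactly -123 degrees. The islands therefore sit in the middle of the zone, well away from the zone edges where UTM distortion is largest.

```python
def utm_zone_number(lon_deg):
    """Standard six-degree UTM zone number for a longitude in [-180, 180).

    Like utm_zone_from_lon above, this ignores the special zones
    around Norway and Svalbard."""
    return int((lon_deg + 180.0) // 6.0) + 1


def central_meridian(zone):
    """Longitude (degrees) of the central meridian of a UTM zone.

    Zone 1 spans -180 to -174 degrees, so its central meridian is -177."""
    return -183.0 + 6.0 * zone


center_lon = 0.5 * (-123.24 + -122.76)  # San Juan Islands region
zone = utm_zone_number(center_lon)
print(zone, central_meridian(zone))
```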

Next, we define a function that creates a Landlab `RasterModelGrid` containing the GMT grid values as a **field**, and with the node `x` and `y` determined from the GMT grid's coordinates.

In [None]:
def gmt_grid_to_raster(gmt_grid, field_for_z="topographic__elevation", tol=1.0e-8):
    """
    Create and return a Landlab RasterModelGrid from a PyGMT
    grid object.

    Assumes regular grid-node spacing.

    Parameters
    ----------
    gmt_grid : PyGMT xarray DataArray
        PyGMT grid, eg from pygmt.datasets.load_earth_relief()
    field_for_z : string (optional)
        Name of field for the "z" values of the PyGMT grid
    tol : float (optional)
        Tolerance for unevenness of node spacing, in the grid's horizontal
        coordinate units (default 1e-8)

    Notes
    -----
    Preserves the lower-left corner coordinates but NOT
    information about projection or datum.
    """
    try:
        x = gmt_grid.x.values
        y = gmt_grid.y.values
    except AttributeError:  # no .x and .y, so assume geographic (lon, lat)
        print("Warning: grid appears to be in geographic coordinates")
        x = gmt_grid.lon.values
        y = gmt_grid.lat.values

    data = gmt_grid.data
    if y[-1] < y[0]:  # rows run north to south: flip so row 0 is southernmost
        y = y[::-1]
        data = data[::-1]

    ny, nx = data.shape
    # Row-to-row (y) and column-to-column (x) spacings should each be uniform
    spacings_ns = np.diff(y)
    max_diff_ns = np.max(np.abs(spacings_ns - spacings_ns[0]))
    if max_diff_ns > tol:
        print("Warning: non-uniform north-south spacing.")
        print("Max spacing difference", max_diff_ns)
    dy = float(spacings_ns[0])
    spacings_ew = np.diff(x)
    max_diff_ew = np.max(np.abs(spacings_ew - spacings_ew[0]))
    if max_diff_ew > tol:
        print("Warning: non-uniform east-west spacing.")
        print("Max spacing difference", max_diff_ew)
    dx = float(spacings_ew[0])

    # Landlab numbers nodes from the lower-left corner, row by row upward
    grid = RasterModelGrid(
        (ny, nx), xy_spacing=(dx, dy), xy_of_lower_left=(float(x[0]), float(y[0]))
    )

    z = grid.add_field(
        field_for_z,
        data.flatten(),
        at="node",
        units=gmt_grid.attrs.get("units", "-"),
        copy=True,
        clobber=False,
    )

    grid.status_at_node[np.isnan(z)] = grid.BC_NODE_IS_CLOSED
    z[np.isnan(z)] = 0.0

    return grid

This function starts out by testing whether the GMT grid is in geographic or projected coordinates. If it has `.x` and `.y` attributes, then we assume it is already projected. If it does *not* have these, then we assume it's in geographic coordinates, so we warn the user and proceed with `lat` for `y` and `lon` for `x`. Note that a `try...except` block is used for this test.

Next, we test the grid's north-south and east-west spacing. These should be uniform: the north-south and east-west spacings can differ from each other, but the spacing between rows should be the same from row to row, and the spacing between columns should be the same from column to column. To test this, we use the `numpy.diff()` function. If every spacing differs from the first one by no more than a specified tolerance (`tol`, in the grid's horizontal units), all is well; otherwise, we issue a warning message.
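The uniformity test can be sketched in isolation. The function below (hypothetical name `check_uniform_spacing`) applies the same `numpy.diff` comparison to a 1-D coordinate array and reports the largest deviation. Note that `tol` is in the same units as the coordinates: meters for a projected grid, degrees for a geographic one.

```python
import numpy as np


def check_uniform_spacing(coords, tol):
    """Return (spacing, max_deviation, ok) for a 1-D coordinate array.

    Every gap between neighboring coordinates is compared with the
    first gap; `ok` is True if none deviates by more than `tol`."""
    gaps = np.diff(np.asarray(coords, dtype=float))
    deviation = np.max(np.abs(gaps - gaps[0]))
    return gaps[0], deviation, deviation <= tol


# Evenly spaced 90 m coordinates versus a copy with one shifted node
even = 500000.0 + 90.0 * np.arange(6)
uneven = even.copy()
uneven[3] += 0.5

print(check_uniform_spacing(even, tol=1.0e-6))
print(check_uniform_spacing(uneven, tol=1.0e-6))
```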

Next we create a Landlab `RasterModelGrid` with the appropriate number of rows and columns, and appropriate spacings between rows and between columns. We also record the $(x, y)$ coordinates of the lower-left corner so that the correct UTM coordinates are retained.
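Recording `x[0], y[0]` as the lower-left corner only works if the first row of the GMT array is the southernmost one, because Landlab numbers raster nodes starting at the lower-left corner, proceeding left to right along each row and then row by row upward. This sketch uses a tiny hypothetical array to show that a row-major `flatten()` then puts each value on the correct node, and that an array stored north-first would need flipping beforehand.

```python
import numpy as np

# Hypothetical 3-row x 4-column elevations; row 0 is the SOUTHERNMOST row
elev = np.array(
    [
        [0.0, 1.0, 2.0, 3.0],  # row 0: south
        [10.0, 11.0, 12.0, 13.0],  # row 1
        [20.0, 21.0, 22.0, 23.0],  # row 2: north
    ]
)

# Node ID for (row r, column c) is r * n_columns + c, so a row-major
# flatten lines values up with Landlab's node numbering.
z_at_node = elev.flatten()
print(z_at_node[0], z_at_node[4], z_at_node[-1])

# An array stored north-to-south must be flipped before flattening
# so that node 0 is still the south-west corner.
elev_north_first = elev[::-1]
z_fixed = np.flipud(elev_north_first).flatten()
```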

For elevation values, we create a new **field**, the default name for which is `topographic__elevation` (but this can be overridden if we are dealing with some other gridded quantity).

Projection commonly assigns a `nan` value to some grid nodes around the perimeter, typically where the rectangular output grid extends beyond the footprint of the original lat-lon grid. We give these nodes the status code `BC_NODE_IS_CLOSED` to indicate that they do not contain valid data, and then set their elevation to zero so that the field contains no `nan` values.
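The masking step can be sketched with plain NumPy. The integer codes below are hypothetical stand-ins for Landlab's `BC_NODE_IS_CORE` and `BC_NODE_IS_CLOSED` constants (in real code, use the grid's constants). The point is the order of operations, which matches the function above: the `nan` mask is computed before the elevations are overwritten with zero, since afterward the missing-data nodes could no longer be identified.

```python
import numpy as np

# Hypothetical projected elevations (meters); nan marks nodes with no data
z = np.array([np.nan, 12.0, 35.0, np.nan, -4.0, 80.0])

# Stand-in status codes (hypothetical; use grid.BC_NODE_IS_* in Landlab)
CORE, CLOSED = 0, 4
status = np.full(z.size, CORE)

no_data = np.isnan(z)  # take the mask BEFORE filling in the nans
status[no_data] = CLOSED
z[no_data] = 0.0

print(status)
print(z)
```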

Here we'll apply this function to our example DEM, and use the grid's `imshow` plotting method to display the result:

In [None]:
llgrid = gmt_grid_to_raster(proj_gmt_grid)

In [None]:
llgrid.imshow(
    llgrid.at_node["topographic__elevation"],
    colorbar_label="Elevation (m)",
)
plt.title("San Juan Islands")

### Use the resulting Landlab grid in a component

As an example, here we use the `Radiation` component to calculate incident solar radiation on the terrain surface at noon on January 1st.

In [None]:
from landlab.components import Radiation

The `Radiation` component expects all core nodes to be above sea level, and of course in this case many of them are underwater. We'll set those underwater nodes to closed-boundary status, so that the `Radiation` component will ignore them.

In [None]:
is_underwater = llgrid.at_node["topographic__elevation"] < 0.0
llgrid.status_at_node[is_underwater] = llgrid.BC_NODE_IS_CLOSED

In [None]:
# Instantiate a Radiation component with our elevation data
# and the approximate latitude...
rad = Radiation(
    llgrid,
    latitude=48.5,
)

# ...and run it
rad.update()

Display the radiation ratio, which is the ratio of incident solar radiation to what a flat surface would receive at the same latitude and time of year. This field is defined at grid cells rather than nodes:

In [None]:
llgrid.imshow(llgrid.at_cell["radiation__ratio_to_flat_surface"], values_at="cell")
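For intuition about what the radiation ratio measures, here is a deliberately simplified sketch that considers only the direct solar beam at solar noon on January 1st, using a common textbook approximation for solar declination (Cooper's formula). The function names are hypothetical, and this is not the algorithm inside Landlab's `Radiation` component. In this simplified model, at latitude 48.5 degrees N the noon sun is only about 18.5 degrees above the horizon, so a 20-degree south-facing slope receives roughly twice the direct beam of flat ground, while a 20-degree north-facing slope receives none.

```python
import math


def solar_declination_deg(day_of_year):
    """Approximate solar declination in degrees (Cooper's formula)."""
    return 23.45 * math.sin(math.radians(360.0 * (284 + day_of_year) / 365.0))


def noon_direct_beam_ratio(lat_deg, day_of_year, slope_deg, faces_south):
    """Direct-beam irradiance on a slope divided by that on a flat surface,
    at solar noon, for a slope facing due south or due north.

    Simplified sketch for northern mid-latitudes: ignores diffuse light,
    clouds, and shading by surrounding terrain. NOT the algorithm used
    by Landlab's Radiation component."""
    # Altitude of the sun above the horizon at solar noon
    altitude = 90.0 - lat_deg + solar_declination_deg(day_of_year)
    # The noon sun is due south, so a south-facing slope tilts toward it
    tilt = slope_deg if faces_south else -slope_deg
    # Irradiance scales with the sine of the angle between rays and surface
    slope_term = max(0.0, math.sin(math.radians(altitude + tilt)))
    flat_term = math.sin(math.radians(altitude))
    return slope_term / flat_term


# January 1st (day 1) at latitude 48.5 N, for 20-degree slopes
south = noon_direct_beam_ratio(48.5, 1, 20.0, faces_south=True)
north = noon_direct_beam_ratio(48.5, 1, 20.0, faces_south=False)
print(f"south-facing: {south:.2f}, north-facing: {north:.2f}")
```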

The default `imshow` settings are OK, and Matplotlib provides plenty of options for customizing plots like this one. But the (Py)GMT package was specifically designed for mapping, and provides lots of great features for making nice maps. It's therefore useful to be able to get our data back into a format PyGMT can understand.

To do that, we'll first define a function to convert a field in a Landlab grid into a PyGMT-format grid:

In [None]:
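def landlab_to_gmt_grid(
    grid,
    field,
    at="node",
    attrs=None,
):
    """
    Convert a field on a Landlab RasterModelGrid into an xarray
    DataArray with "y" and "x" coordinates, usable as a PyGMT grid.

    Parameters
    ----------
    grid : RasterModelGrid
        Landlab grid containing the field
    field : string
        Name of the field to convert
    at : string (optional)
        Location of the field: "node", "cell", or "corner"
    attrs : dict (optional)
        Metadata (e.g., units) to attach to the DataArray
    """
    if not isinstance(grid, RasterModelGrid):
        raise TypeError("grid must be RasterModelGrid")
    if at == "node":
        nr = grid.number_of_node_rows
        nc = grid.number_of_node_columns
        x = grid.x_of_node[:nc]  # first row of nodes: one x per column
        y = grid.y_of_node[: grid.number_of_nodes : nc]  # one y per row
    elif at == "cell":
        # Each raster cell is centered on an interior node
        nr = grid.number_of_cell_rows
        nc = grid.number_of_cell_columns
        x = grid.x_of_node[grid.node_at_cell][:nc]
        y = grid.y_of_node[grid.node_at_cell][: grid.number_of_cells : nc]
    elif at == "corner":
        nr = grid.number_of_corner_rows
        nc = grid.number_of_corner_columns
        x = grid.x_of_corner[:nc]
        y = grid.y_of_corner[: grid.number_of_corners : nc]
    else:
        raise ValueError("'at' must be node, cell, or corner")

    # Fields are flat arrays ordered from the lower left, so row 0 is south
    gmtgrd = xr.DataArray(
        data=grid[at][field].reshape((nr, nc)),
        dims=("y", "x"),
        coords={
            "y": y,
            "x": x,
        },
        attrs=attrs,
    )
    return gmtgrd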

Let's apply this to the radiation ratio field:

In [None]:
field = "radiation__ratio_to_flat_surface"
radgmtgrd = landlab_to_gmt_grid(
    llgrid,
    field,
    at="cell",
    attrs={"units": "-", "name": "radiation__ratio_to_flat_surface"},  # ratio is dimensionless
)
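The coordinate bookkeeping in `landlab_to_gmt_grid` can be checked on a small example. This sketch uses plain NumPy and a hypothetical 4-row by 5-column raster (spacing and lower-left coordinates are made up). It rebuilds node coordinates in Landlab's ordering and applies the same slices as the function: `[:nc]` takes each column's x from the first row, `[::nc]` takes each row's y from the first column, and cells take their coordinates from the interior nodes on which they are centered.

```python
import numpy as np

# Hypothetical raster: 4 node rows x 5 node columns, 100 m spacing,
# lower-left node at (500000, 5360000)
nr, nc, spacing = 4, 5, 100.0
x0, y0 = 500000.0, 5360000.0
n_nodes = nr * nc

# Node coordinates in Landlab order: left to right along each row,
# then row by row upward from the bottom
x_of_node = np.tile(x0 + spacing * np.arange(nc), nr)
y_of_node = np.repeat(y0 + spacing * np.arange(nr), nc)

# 1-D coordinate vectors, as in landlab_to_gmt_grid(at="node")
x = x_of_node[:nc]  # first row of nodes: one x per column
y = y_of_node[:n_nodes:nc]  # first column of nodes: one y per row

# On a raster, each cell is centered on an interior (non-perimeter) node
rows, cols = np.meshgrid(np.arange(1, nr - 1), np.arange(1, nc - 1), indexing="ij")
node_at_cell = (rows * nc + cols).ravel()
ncc = nc - 2  # number of cell columns
x_cell = x_of_node[node_at_cell][:ncc]
y_cell = y_of_node[node_at_cell][: node_at_cell.size : ncc]

# A flat node field reshapes to (rows, columns), with row 0 at the bottom
values = np.arange(n_nodes, dtype=float).reshape((nr, nc))
print(x_cell, y_cell, values[0, :3])
```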

To take full advantage of PyGMT's capabilities (at least for
a GMT newbie like this author), we'll project back into geographic coordinates:

In [None]:
btg = pygmt.grdproject(
    radgmtgrd,
    inverse=True,
    projection="U10N/15c",  # UTM zone 10N, the zone used for the forward projection
    region=san_juan_islands,
    scaling="e",
    center=[1.0e-6, 1.0e-6],  # some nonzero val needed to preserve offset...?
)

Finally, let's use PyGMT to display one of the outputs of the Landlab calculation: the ratio of incident solar radiation to that on a flat surface. We'll display the PyGMT grid using `grdimage` and overlay coastlines using `coast`, coloring the oceans light blue. We'll use `text` to label the three largest islands, and add a `colorbar` at the bottom of the plot.

In [None]:
fig = pygmt.Figure()
fig.grdimage(
    btg,
    region=san_juan_islands,
    projection="M15c",  # Mercator projection, 15 cm wide
    frame="a",
    cmap="SCM/navia",
)
fig.coast(
    region=san_juan_islands,  # Set the bounding box of the map
    borders="2/thin",  # Plot state boundaries with thin lines
    shorelines="thin",  # Plot coastline with thin lines
    projection="M15c",  # Mercator projection, 15 cm wide
    water="lightblue",  # Color water areas light blue
    frame="a",  # Frame with automatic annotation intervals
)
fig.text(
    x=[-123.08, -122.91, -122.9],
    y=[48.53, 48.61, 48.48],
    font="12p,Helvetica-Bold,white",
    text=[
        "San Juan",
        "Orcas",
        "Lopez",
    ],
)
fig.colorbar(frame=["a0.25", "x+lRadiation ratio"], truncate=[0.0, np.nanmax(btg.data)])  # nanmax skips nan edge nodes
fig.show()