# Calculate the forward gravity of ice and firn

We'll load datasets of ice-surface elevation, ice thickness and firn thickness, convert them into layers of vertical prisms, assign each prism a density, and calculate the gravity effect of each layer of prisms.

# Import packages

In [None]:
# %%capture
%load_ext autoreload
%autoreload 2

import harmonica as hm
import verde as vd
import xarray as xr
import numpy as np
import pyvista as pv
import pygmt
from typing import Union
from antarctic_plots import maps, utils, regions, fetch
import RIS_gravity_inversion.inversion as inv 

# Useful functions

In [None]:
def grids_to_prisms(
    top: xr.DataArray,
    bottom: xr.DataArray,
    density: Union[float, int, xr.DataArray],
    plot: bool = False,
    **kwargs,
):
    """
    Convert a pair of grids into a layer of vertical prisms.

    Each grid cell becomes one prism whose upper boundary comes from ``top``
    and whose lower boundary comes from ``bottom``.

    Parameters
    ----------
    top : xr.DataArray
        Grid of prism top elevations; its ``x`` and ``y`` coordinates are used
        as the prism-layer coordinates.
    bottom : xr.DataArray
        Grid of prism bottom elevations, on the same grid as ``top``.
    density : float, int or xr.DataArray
        Prism density. A single number (Python or numpy scalar) is applied to
        every prism; a DataArray is used as-is as a per-prism density grid.
    plot : bool, optional
        If True, display the prisms with ``show_prisms``, by default False.
    **kwargs
        Passed on to ``show_prisms`` when ``plot`` is True.

    Returns
    -------
    xr.Dataset
        Prism layer from ``harmonica.prism_layer`` with ``density`` and
        ``thickness`` (``top - bottom``) properties.

    Raises
    ------
    ValueError
        If ``density`` is neither a number nor a DataArray.
    """
    # A single number is broadcast to every prism. numpy scalars are accepted
    # as well: np.float64 already subclasses float, but np.float32 and numpy
    # integers (e.g. values pulled out of an array) do not.
    if isinstance(density, (float, int, np.integer, np.floating)):
        dens = density * np.ones_like(top)
    # a DataArray is taken to already hold one density per prism
    elif isinstance(density, xr.DataArray):
        dens = density
    else:
        raise ValueError("invalid density type, should be a number or DataArray")

    # create layer of prisms based off input dataarrays
    prisms = hm.prism_layer(
        coordinates=(top.x.values, top.y.values),
        surface=top,
        reference=bottom,
        properties={
            'density': dens,
            'thickness': top - bottom,
        },
    )

    if plot:
        show_prisms(prisms, **kwargs)

    return prisms

def show_prisms(
    prisms: xr.Dataset,
    cmap='viridis',
    color_by='thickness',
    **kwargs,
    ):
    """
    Display a single prism layer in an interactive pyvista scene.

    Parameters
    ----------
    prisms : xr.Dataset
        Prism layer supporting the harmonica ``prism_layer`` accessor, with
        ``easting`` and ``northing`` coordinates.
    cmap : str, optional
        Colormap for the prisms, by default 'viridis'.
    color_by : str, optional
        Prism property used to colour the mesh, by default 'thickness'.
    **kwargs
        Optional display settings: ``backend`` ('ipyvtklink'),
        ``flip_scalars`` (False), ``zscale`` (vertical exaggeration, 75),
        ``camera_position`` ("xz"), ``elevation`` (20), ``azimuth`` (-25) and
        ``zoom`` (1.2).
    """
    option = kwargs.get

    pv.set_jupyter_backend(option('backend', 'ipyvtklink'))

    mesh = prisms.prism_layer.to_pyvista()

    plotter = pv.Plotter(
        notebook=True,
        lighting="three_lights",
        # window_size=(1000, 1000),
    )
    plotter.add_mesh(
        mesh,
        style='surface',
        scalars=color_by,
        cmap=cmap,
        flip_scalars=option('flip_scalars', False),
        smooth_shading=True,
        show_edges=False,
    )

    # vertical exaggeration followed by the initial viewpoint
    plotter.set_scale(zscale=option('zscale', 75))
    plotter.camera_position = option('camera_position', "xz")
    plotter.camera.elevation = option('elevation', 20)
    plotter.camera.azimuth = option('azimuth', -25)
    plotter.camera.zoom(option('zoom', 1.2))

    # Ceiling light: a fixed scene light (it does not follow the camera),
    # non-positional so it shines from infinity, aimed straight down at the
    # centre of the prisms.
    w, e, s, n = vd.get_region((prisms.easting, prisms.northing))
    centre_x = (e + w) / 2
    centre_y = (n + s) / 2
    plotter.add_light(
        pv.Light(
            position=(centre_x, centre_y, 10e3),
            focal_point=(centre_x, centre_y, 0),
            intensity=0.3,
            light_type="scene light",
            positional=False,
        )
    )

    plotter.show_axes()
    plotter.show()

def show_prism_layers(
    prisms: list,
    cmap='viridis',
    color_by='density',
    **kwargs,
    ):
    """
    Display several prism layers together in one interactive pyvista scene.

    Parameters
    ----------
    prisms : list
        Prism-layer Datasets (e.g. from ``grids_to_prisms``); each must
        support the harmonica ``prism_layer`` accessor and have ``easting``
        and ``northing`` coordinates.
    cmap : str, optional
        Colormap used for every layer, by default 'viridis'.
    color_by : str, optional
        Prism property used to colour the meshes, by default 'density'.
    **kwargs
        Optional display settings: ``backend`` ('ipyvtklink'),
        ``flip_scalars`` (False), ``zscale`` (vertical exaggeration, 75),
        ``camera_position`` ("xz"), ``elevation`` (20), ``azimuth`` (-25),
        ``zoom`` (1.2) and ``clim`` (colour limits; by default the combined
        min/max of ``color_by`` across all layers).

    Raises
    ------
    ValueError
        If ``prisms`` is empty.
    """
    if not prisms:
        raise ValueError("prisms must contain at least one prism layer")

    pv.set_jupyter_backend(kwargs.get('backend', 'ipyvtklink'))

    # Share a single colour range between all layers. Otherwise each mesh is
    # auto-scaled to its own values, and layers of different constant density
    # may be drawn in the same colour.
    clim = kwargs.get('clim')
    if clim is None and color_by is not None:
        lows = [float(layer[color_by].min()) for layer in prisms]
        highs = [float(layer[color_by].max()) for layer in prisms]
        clim = [float(np.nanmin(lows)), float(np.nanmax(highs))]

    plotter = pv.Plotter(
        lighting="three_lights",
        # window_size=(1000, 1000),
        notebook=True,
        )
    for layer in prisms:
        plotter.add_mesh(
            layer.prism_layer.to_pyvista(),
            scalars=color_by,
            cmap=cmap,
            clim=clim,
            flip_scalars=kwargs.get('flip_scalars', False),
            smooth_shading=True,
            style='surface',
            show_edges=False,
            )

    # Configure the view once, after every layer has been added;
    # ``camera.zoom`` is relative, so it must not be applied per layer.
    plotter.set_scale(zscale=kwargs.get('zscale', 75))  # exaggerate the vertical coordinate
    plotter.camera_position = kwargs.get('camera_position', "xz")
    plotter.camera.elevation = kwargs.get('elevation', 20)
    plotter.camera.azimuth = kwargs.get('azimuth', -25)
    plotter.camera.zoom(kwargs.get('zoom', 1.2))

    # Add a ceiling light centred on the combined extent of all layers
    extents = [vd.get_region((layer.easting, layer.northing)) for layer in prisms]
    west = min(ext[0] for ext in extents)
    east = max(ext[1] for ext in extents)
    south = min(ext[2] for ext in extents)
    north = max(ext[3] for ext in extents)
    easting_center, northing_center = (east + west) / 2, (north + south) / 2
    light = pv.Light(
        position=(easting_center, northing_center, 10e3),
        focal_point=(easting_center, northing_center, 0),
        intensity=0.3,
        light_type="scene light",  # the light doesn't move with the camera
        positional=False,  # the light comes from infinity
    )
    plotter.add_light(light)

    plotter.show_axes()
    plotter.show()

# Gather datasets

## Bedmap2
* 1 km resolution
* "The ice thickness measurements compiled for Bedmap2, thus, represent the researchers’ best estimate of the physical ice thickness, rather than an “ice-equivalent” thickness".
* firn correction assumed to have been applied in all original Bedmap1 data

## Bedmachine
* 500m resolution
* Used REMA surface, and subtracted a firn-correction to get bedmachine_surface
* "The surface elevation and ice thickness are in ice equivalent as they include a firn air content correction. The elevation of the top of the snow, which is provided by REMA [Howat et al., 2019], can be calculated by adding the firn depth correction provided in the netCDF"
* "we derive the ice thickness of floating ice shelves by relying on the hydrostatic equilibrium"
* "The ice shelf thickness is therefore the sum of an ice equivalent ice thickness and firn depth correction"

## REMA
* 1km resolution (finer versions available)
* referenced to WGS-84 ellipsoid


Set the grid properties shared by all datasets (region, spacing, registration and vertical reference) so that every grid is identical.

In [None]:
# Grid settings shared by every dataset fetched below, so all grids match
region = regions.ross_ice_shelf
spacing = 5e3  # grid spacing
registration = 'g'  # gridline registration
reference = 'geoid'  # vertical reference passed to the elevation fetches

Create a mask to retain only data within the Ross Ice Shelf

In [None]:
# Build a Ross Ice Shelf mask from the MEaSUREs outline on the common grid.
# The pixel-registration flag is derived from the shared ``registration``
# setting (instead of being hard-coded) so that the mask always lines up with
# the fetched grids.
mask = utils.mask_from_shp(
    "plotting/MEaSUREs_RIS.shp",
    masked=True,
    invert=False,
    region=region,
    spacing=spacing,
    pixel_register=(registration == 'p'),
    )
print("Mask info:", utils.get_grid_info(mask))
mask.plot()

In [None]:
# Bedmap2 ice-surface elevation on the common grid, using the shared
# vertical reference
bedmap2_surface = fetch.bedmap2(
    layer='surface',
    verbose='e',
    region=region,
    spacing=spacing,
    registration=registration,
    reference=reference,
)

In [None]:
# Bedmap2 ice thickness on the common grid (no vertical reference needed)
bedmap2_thickness = fetch.bedmap2(
    layer='thickness',
    verbose='e',
    region=region,
    spacing=spacing,
    registration=registration,
)

In [None]:
# Restrict Bedmap2 to the Ross Ice Shelf in the same way the Bedmachine grids
# are masked below. Without this, the Bedmap2 forward gravity includes
# grounded ice that the Bedmachine forward gravity excludes, and the
# comparison of the two would mix different areas.
# NOTE(review): assumes ``mask`` is 1 inside the shelf and NaN outside (as its
# use with the Bedmachine grids implies), so re-running this cell is harmless.
bedmap2_surface = bedmap2_surface * mask
bedmap2_thickness = bedmap2_thickness * mask

# base of the ice = surface elevation minus ice thickness
bedmap2_icebase = bedmap2_surface - bedmap2_thickness

print("Bedmap2 surface info:", utils.get_grid_info(bedmap2_surface))
print("Bedmap2 thickness info:", utils.get_grid_info(bedmap2_thickness))
print("Bedmap2 icebase info:", utils.get_grid_info(bedmap2_icebase))

In [None]:
# Bedmap2 ice layer: prisms from the ice surface down to the ice base,
# all with density 917, shown coloured by thickness
bedmap2_ice_prisms = grids_to_prisms(
    top=bedmap2_surface,
    bottom=bedmap2_icebase,
    density=917,
    plot=True,
    cmap='viridis',
    color_by='thickness',
    # backend='static',
)

In [None]:
# Optional ocean mask from the Bedmachine 'mask' layer (not currently used):
# bedmachine_mask = fetch.bedmachine(
#     layer='mask',
#     region=region,
#     spacing=spacing,
#     registration=registration,
#     reference=reference,
#     verbose='e',
#     )
# # 0 over ocean, 1 elsewhere
# ocean_mask = bedmachine_mask.where(bedmachine_mask == 0, 1, 0)
# ocean_mask = ocean_mask.where(ocean_mask != 0)

# Each Bedmachine layer is fetched on the common grid and multiplied by
# ``mask`` to keep only the Ross Ice Shelf.

# ice surface (fetched with the shared vertical reference)
bedmachine_surface = fetch.bedmachine(
    layer='surface',
    verbose='e',
    region=region,
    spacing=spacing,
    registration=registration,
    reference=reference,
) * mask

# ice thickness (fetched without a vertical reference)
bedmachine_thickness = fetch.bedmachine(
    layer='thickness',
    verbose='e',
    region=region,
    spacing=spacing,
    registration=registration,
) * mask

# firn depth correction
bedmachine_firn = fetch.bedmachine(
    layer='firn',
    verbose='e',
    region=region,
    spacing=spacing,
    registration=registration,
    reference=reference,
) * mask

# base of the ice = surface elevation minus ice thickness
bedmachine_icebase = bedmachine_surface - bedmachine_thickness

print("bedmachine surface info:", utils.get_grid_info(bedmachine_surface))
print("bedmachine thickness info:", utils.get_grid_info(bedmachine_thickness))
print("bedmachine icebase info:", utils.get_grid_info(bedmachine_icebase))
print("bedmachine firn info:", utils.get_grid_info(bedmachine_firn))

In [None]:
# Bedmachine ice layer: prisms from the (ice-equivalent) surface down to the
# ice base, all with density 917, shown coloured by thickness
bedmachine_ice_prisms = grids_to_prisms(
    top=bedmachine_surface,
    bottom=bedmachine_icebase,
    density=917,
    plot=True,
    cmap='viridis',
    color_by='thickness',
    # backend='static',
)

In [None]:
# Bedmachine firn layer: prisms from the surface up to surface + firn
# correction (the top of the snow), all with density 500.
# NOTE(review): the ice prisms above use ice-equivalent thickness, and the
# Bedmachine firn term is described as a firn *air content* correction. Giving
# this layer a density of 500 may therefore add mass that is not present in
# the real column. Confirm the intended firn model.
bedmachine_firn_prisms = grids_to_prisms(
    top=bedmachine_surface + bedmachine_firn,
    bottom=bedmachine_surface,
    density=500,
    plot=True,
    cmap='viridis',
    color_by='thickness',
    # backend='static',
)

In [None]:
# Bedmachine ice and firn layers together, coloured by density
show_prism_layers(
    prisms=[bedmachine_ice_prisms, bedmachine_firn_prisms],
    color_by='density',
    cmap='viridis',
)

In [None]:
# REMA surface (1 km version) on the common region and spacing. The shared
# ``registration`` is not passed, so the fetch default is used.
REMA = fetch.REMA(
    region=region,
    spacing=spacing,
    version=1e3,
    # registration=registration,
)
print("REMA surface info:", utils.get_grid_info(REMA))

# Optional map of the REMA surface:
# fig = maps.plot_grd(
#     REMA,
#     coast=True,
#     region=region,
#     grd2cpt=True,
#     cmap='haline',
#     cbar_label='m',
#     title='REMA surface',
#     shading="+a45+nt0.2",
#     )
# fig.show()

# Forward gravity

## Bedmap2 

$
Z_{top} = Z_{surface}
$

$
Z_{bottom} = Z_{surface} - t_{ice}
$

## Bedmachine (ice equivalent thickness)

$
Z_{top} = Z_{surface}
$

$
Z_{bottom} = Z_{surface} - t_{ice}
$

## Bedmachine2 (true geometry)

"The surface elevation and ice thickness are in ice equivalent as they include a firn air content correction. The elevation of the top of the snow, which is provided by REMA [Howat et al., 2019], can be calculated by **adding the firn depth** correction provided in the netCDF"

* $
Z_{top} = Z_{surface} + t_{firn}
$

"The ice shelf thickness, $t_{shelf}$, is therefore the sum of an ice equivalent ice thickness, $t_{ice}$, and firn depth correction, $t_{firn}$."

* $
t_{shelf} = t_{ice} + t_{firn}
$

Therefore,

* $
Z_{bottom} = Z_{top} - t_{shelf}
$

Substituting in $Z_{top}$ and $t_{shelf}$,

* $
Z_{bottom} = (Z_{surface} + t_{firn}) - (t_{ice} + t_{firn})
$

Simplifying

* $
Z_{bottom} = Z_{surface} - t_{ice}
$


Here, we'll use the forward gravity calculations provided by _Harmonica_ to calculate the gravity effect of each vertical prism. Conveniently, we can do this for an entire layer of prisms at once, instead of prism by prism, using the **prism_layer.gravity** accessor method.

We need to tell Harmonica where (spatially) to calculate the gravity effect of the prisms. We create a grid of evenly spaced observation points at a constant elevation: 10 km spacing, at a height of 1 km.

In [None]:
# Observation points: a regular grid over ``region`` at a constant height.
# A dedicated name is used for the observation spacing rather than rebinding
# ``spacing``. The data-fetching cells above depend on ``spacing`` (5e3), and
# overwriting it would make re-running them fetch grids at a different
# spacing from ``mask``.
obs_spacing = 10e3
obs_height = 1e3  # constant elevation of every observation point
observation_points = vd.grid_coordinates(
    region, spacing=obs_spacing, extra_coords=obs_height
)

In [None]:
# Vertical gravity component (g_z) of the Bedmap2 ice prisms at every
# observation point
grav = bedmap2_ice_prisms.prism_layer.gravity(
    coordinates=observation_points,
    field='g_z',
    progressbar=True,
)

# Grid the result, carrying the observation height as an extra coordinate
bedmap2_ice_grav = vd.make_xarray_grid(
    coordinates=observation_points,
    data=grav,
    data_names='grav',
    extra_coords_names='height',
)["grav"]

fig = maps.plot_grd(
    bedmap2_ice_grav,
    title="Forward gravity of Bedmap2 ice prisms",
    cbar_label="mGal",
    grd2cpt=True,
    coast=True,
)
fig.show()


In [None]:
# Vertical gravity component (g_z) of the Bedmachine ice prisms at every
# observation point
grav = bedmachine_ice_prisms.prism_layer.gravity(
    coordinates=observation_points,
    field='g_z',
    progressbar=True,
)

# Grid the result, carrying the observation height as an extra coordinate
bedmachine_ice_grav = vd.make_xarray_grid(
    coordinates=observation_points,
    data=grav,
    data_names='grav',
    extra_coords_names='height',
)["grav"]

fig = maps.plot_grd(
    bedmachine_ice_grav,
    title="Forward gravity of Bedmachine ice prisms",
    cbar_label="mGal",
    grd2cpt=True,
    coast=True,
)
fig.show()


In [None]:
# Compare the Bedmap2 and Bedmachine forward-gravity grids
diff = utils.grd_compare(
    bedmap2_ice_grav,
    bedmachine_ice_grav,
    plot=True,
)