Skip to content

Commit

Permalink
fix: replace pygmt gridding with simple set_index for plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed May 7, 2024
1 parent 9218e68 commit ebf6732
Showing 1 changed file with 6 additions and 45 deletions.
51 changes: 6 additions & 45 deletions src/invert4geom/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,6 @@ def grid_inversion_results(
prisms_ds: xr.Dataset,
grav_results: pd.DataFrame,
region: tuple[float, float, float, float],
spacing: float,
registration: str,
) -> tuple[list[xr.DataArray], list[xr.DataArray], list[xr.DataArray]]:
"""
create grids from the various data variables of the supplied gravity dataframe and
Expand All @@ -266,10 +264,6 @@ def grid_inversion_results(
resulting dataframe of gravity data from the inversion
region : tuple[float, float, float, float]
region to use for gridding in format (xmin, xmax, ymin, ymax)
spacing : float
grid spacing in meters
registration : str
grid registration type, either "g" for gridline or "p" for pixel
Returns
-------
Expand All @@ -278,13 +272,7 @@ def grid_inversion_results(
"""
misfit_grids = []
for m in misfits:
grid = pygmt.xyz2grd(
data=grav_results[["easting", "northing", m]],
region=region,
spacing=spacing,
registration=registration,
verbose="q",
)
grid = grav_results.set_index(["northing", "easting"]).to_xarray()[m]
misfit_grids.append(grid)

topo_grids = []
Expand Down Expand Up @@ -414,8 +402,6 @@ def plot_inversion_topo_results(
def plot_inversion_grav_results(
grav_results: pd.DataFrame,
region: tuple[float, float, float, float],
spacing: float,
registration: str,
iterations: list[int],
) -> None:
"""
Expand All @@ -427,37 +413,22 @@ def plot_inversion_grav_results(
resulting dataframe of gravity data from the inversion
region : tuple[float, float, float, float]
region to use for gridding in format (xmin, xmax, ymin, ymax)
spacing : float
grid spacing in meters
registration : str
grid registration type, either "g" for gridline or "p" for pixel
iterations : list[int]
list of all the iteration numbers
"""
initial_misfit = pygmt.xyz2grd(
data=grav_results[["easting", "northing", "iter_1_initial_misfit"]],
region=region,
spacing=spacing,
registration=registration,
verbose="q",
)

final_misfit = pygmt.xyz2grd(
data=grav_results[
["easting", "northing", f"iter_{max(iterations)}_final_misfit"]
],
region=region,
spacing=spacing,
registration=registration,
verbose="q",
)
grid = grav_results.set_index(["northing", "easting"]).to_xarray()

initial_misfit = grid["iter_1_initial_misfit"]
final_misfit = grid[f"iter_{max(iterations)}_final_misfit"]

initial_rmse = utils.rmse(grav_results["iter_1_initial_misfit"])
final_rmse = utils.rmse(grav_results[f"iter_{max(iterations)}_final_misfit"])

_ = polar_utils.grd_compare(
initial_misfit,
final_misfit,
region=region,
plot=True,
grid1_name=f"Initial misfit: RMSE={round(initial_rmse, 2)} mGal",
grid2_name=f"Final misfit: RMSE={round(final_rmse, 2)} mGal",
Expand Down Expand Up @@ -682,8 +653,6 @@ def plot_inversion_results(
topo_results: pd.DataFrame | str,
parameters: dict[str, typing.Any] | str,
grav_region: tuple[float, float, float, float] | None,
grav_spacing: float,
registration: str = "g",
iters_to_plot: int | None = None,
plot_iter_results: bool = True,
plot_topo_results: bool = True,
Expand All @@ -703,10 +672,6 @@ def plot_inversion_results(
inversion parameters dictionary or filename
grav_region : tuple[float, float, float, float] | None
region to use for gridding in format (xmin, xmax, ymin, ymax), by default None
grav_spacing : float
grid spacing in meters
registration : str, optional
grid registration type, either "g" for gridline or "p" for pixel, by default "g"
iters_to_plot : int | None, optional
number of iterations to plot, including the first and last, by default None
plot_iter_results : bool, optional
Expand Down Expand Up @@ -774,8 +739,6 @@ def plot_inversion_results(
prisms_ds,
grav_results,
grav_region,
grav_spacing,
registration,
)

if plot_iter_results is True:
Expand All @@ -800,8 +763,6 @@ def plot_inversion_results(
plot_inversion_grav_results(
grav_results,
grav_region,
grav_spacing,
registration,
iterations,
)

Expand Down

0 comments on commit ebf6732

Please sign in to comment.