Skip to content

Commit

Permalink
feat: add inversion_region to plot convergence
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed May 7, 2024
1 parent b110f23 commit 52eebdd
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/invert4geom/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def plot_convergence(
results: pd.DataFrame,
iter_times: list[float] | None = None,
logy: bool = False,
inversion_region: tuple[float, float, float, float] | None = None,
figsize: tuple[float, float] = (5, 3.5),
) -> None:
"""
Expand All @@ -206,6 +207,8 @@ def plot_convergence(
list of iteration execution times, by default None
logy : bool, optional
choose whether to plot y axis in log scale, by default False
inversion_region : tuple[float, float, float, float] | None, optional
inside region of inversion, by default None
figsize : tuple[float, float], optional
width and height of figure, by default (5, 3.5)
"""
Expand All @@ -223,23 +226,32 @@ def plot_convergence(
# get misfit data at end of each iteration
cols = [s for s in results.columns.to_list() if "_final_misfit" in s]
iters = len(cols)
final_misfits = [utils.rmse(results[i]) for i in cols]
if inversion_region is not None:
misfits = [utils.rmse(results[results.inside][i]) for i in cols]
starting_misfit = utils.rmse(results[results.inside]["iter_1_initial_misfit"])
else:
misfits = [utils.rmse(results[i]) for i in cols]
starting_misfit = utils.rmse(results["iter_1_initial_misfit"])
# add starting misfit to the beginning of the list
misfits.insert(0, starting_misfit)

_fig, ax1 = plt.subplots(figsize=figsize)
plt.title("Inversion convergence")
ax1.plot(range(iters), final_misfits, "b-")
ax1.plot(range(iters + 1), misfits, "b-")
ax1.set_xlabel("Iteration")
if logy:
ax1.set_yscale("log")
ax1.set_ylabel("RMS (mGal)", color="b")
ax1.tick_params(axis="y", colors="b", which="both")

if iter_times is not None:
iter_times.insert(0, 0)
ax2 = ax1.twinx()
ax2.plot(range(iters), np.cumsum(iter_times), "g-")
ax2.plot(range(iters + 1), np.cumsum(iter_times), "g-")
ax2.set_ylabel("Cumulative time (s)", color="g")
ax2.tick_params(axis="y", colors="g")
ax2.grid(False)

plt.tight_layout()


Expand Down

0 comments on commit 52eebdd

Please sign in to comment.