Skip to content

Commit

Permalink
feat: add 2 parameter cross validation plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed Nov 27, 2023
1 parent 960eb44 commit 2d1269e
Showing 1 changed file with 95 additions and 0 deletions.
95 changes: 95 additions & 0 deletions src/invert4geom/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,101 @@
from invert4geom import utils


def plot_2_parameter_cv_scores(
    scores: list[float],
    parameter_pairs: list[tuple[float, float]],
    logx: bool = False,
    logy: bool = False,
    param_names: tuple[str, str] = ("Hyperparameter 1", "Hyperparameter 2"),
    figsize: tuple[float, float] = (5, 3.5),
    cmap: str = "viridis",
) -> None:
    """
    Plot cross-validation scores for a two-parameter search.

    The scores are drawn as a colored grid with the first parameter on the
    x axis and the second parameter on the y axis. Every tested parameter
    pair is marked with a black dot and the pair with the lowest score is
    highlighted with a square marker labelled "Minimum".

    Parameters
    ----------
    scores : list[float]
        cross-validation score for each entry of ``parameter_pairs``
    parameter_pairs : list[tuple[float, float]]
        the (parameter 1, parameter 2) values which produced each score, in
        the same order as ``scores``
    logx, logy : bool, optional
        make the x or y axes log scale, by default False
    param_names : tuple[str, str], optional
        names of the two parameters, used as axis labels, by default
        ("Hyperparameter 1", "Hyperparameter 2"). The two names must differ
        and neither may be "scores".
    figsize : tuple[float, float], optional
        size of the figure, by default (5, 3.5)
    cmap : str, optional
        matplotlib colormap for scores, by default "viridis"

    Raises
    ------
    ImportError
        if the optional dependencies seaborn or matplotlib are not installed
    ValueError
        if ``scores`` is empty, if ``scores`` and ``parameter_pairs`` have
        different lengths, or if ``param_names`` would collide with each other
        or with the internal "scores" column
    """
    # Check both optional dependencies before changing any global plot state
    if sns is None:
        msg = "Missing optional dependency 'seaborn' required for plotting."
        raise ImportError(msg)
    if plt is None:
        msg = "Missing optional dependency 'matplotlib' required for plotting."
        raise ImportError(msg)
    sns.set_theme()

    x_name, y_name = param_names[0], param_names[1]

    # The names become dataframe columns alongside "scores"; a clash would
    # silently overwrite one column with another.
    if x_name == y_name or "scores" in (x_name, y_name):
        msg = (
            "param_names must be two different names and neither can be "
            f"'scores', got {param_names}."
        )
        raise ValueError(msg)

    if len(scores) != len(parameter_pairs):
        msg = (
            "scores and parameter_pairs must have the same length, got "
            f"{len(scores)} and {len(parameter_pairs)}."
        )
        raise ValueError(msg)

    if len(scores) == 0:
        msg = "scores and parameter_pairs must not be empty."
        raise ValueError(msg)

    df = pd.DataFrame(
        {
            "scores": scores,
            x_name: [pair[0] for pair in parameter_pairs],
            y_name: [pair[1] for pair in parameter_pairs],
        }
    ).sort_values(by="scores")

    # Sorted ascending, so the first row holds the lowest score
    best = df.iloc[0]

    plt.figure(figsize=figsize)
    plt.title("Two parameter cross-validation")

    # Reshape the flat table into a 2D grid (rows = parameter 2, columns =
    # parameter 1) so it can be drawn as a colored mesh.
    # NOTE(review): this presumably requires every parameter pair to be
    # unique, since duplicates cannot be placed on a grid -- confirm.
    grid = df.set_index([y_name, x_name]).to_xarray().scores
    grid.plot(cmap=cmap)

    # Mark every parameter pair that was tested
    plt.scatter(
        df[x_name],  # pylint: disable=unsubscriptable-object
        df[y_name],  # pylint: disable=unsubscriptable-object
        marker=".",
        color="black",
    )

    # Highlight the best (lowest scoring) parameter pair
    plt.plot(
        best[x_name],
        best[y_name],
        "s",
        markersize=10,
        color=sns.color_palette()[3],
        label="Minimum",
    )
    plt.legend(
        loc="upper right",
    )

    if logx:
        plt.xscale("log")
    if logy:
        plt.yscale("log")
    plt.xlabel(x_name)
    plt.ylabel(y_name)
    plt.tight_layout()


def plot_cv_scores(
scores: list[float],
parameters: list[float],
Expand Down

0 comments on commit 2d1269e

Please sign in to comment.