Skip to content

Commit

Permalink
feat: add plot_title kwarg for plot_cv_scores
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed May 7, 2024
1 parent ebf6732 commit 3cda4e0
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/invert4geom/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def plot_cv_scores(
logy: bool = False,
param_name: str = "Hyperparameter",
figsize: tuple[float, float] = (5, 3.5),
plot_title: str | None = None,
) -> None:
"""
plot a graph of cross-validation scores vs hyperparameter values
Expand All @@ -146,6 +147,8 @@ def plot_cv_scores(
name to give for the parameters, by default "Hyperparameter"
figsize : tuple[float, float], optional
size of the figure, by default (5, 3.5)
plot_title : str | None, optional
title of figure, by default None
"""
# Check if seaborn is installed
if sns is None:
Expand All @@ -163,7 +166,10 @@ def plot_cv_scores(
best = df.scores.argmin()

plt.figure(figsize=figsize)
plt.title(f"{param_name} Cross-validation")
if plot_title is not None:
plt.title(plot_title)
else:
plt.title(f"{param_name} Cross-validation")
plt.plot(df.parameters, df.scores, marker="o")
plt.plot(
df.parameters.iloc[best],
Expand Down

0 comments on commit 3cda4e0

Please sign in to comment.