Skip to content

Commit 2097b36

Browse files
committed
Fix Altair chart concatenation and stubs
1 parent a6410a2 commit 2097b36

File tree

4 files changed

+44
-12
lines changed

4 files changed

+44
-12
lines changed

docs/source/how_to/how_to_change_plotting_backend.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
"The returned figure object is an [`altair.Chart`](https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html).\n",
6767
"\n",
6868
"```{note}\n",
69-
"In case of grid plots (such as `convergence_plot` or `slice_plot`), the returned object is an [`altair.VConcatChart`](https://altair-viz.github.io/user_guide/generated/toplevel/altair.VConcatChart.html).\n",
69+
"In case of grid plots (such as `convergence_plot` or `slice_plot`), the returned object is either an [`altair.Chart`](https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html) if there is only one subplot, an [`altair.HConcatChart`](https://altair-viz.github.io/user_guide/generated/toplevel/altair.HConcatChart.html) if there is only one row, or an [`altair.VConcatChart`](https://altair-viz.github.io/user_guide/generated/toplevel/altair.VConcatChart.html) otherwise.\n",
7070
"```\n",
7171
"\n",
7272
":::\n",
@@ -207,7 +207,7 @@
207207
],
208208
"metadata": {
209209
"kernelspec": {
210-
"display_name": "Python 3",
210+
"display_name": "optimagic",
211211
"language": "python",
212212
"name": "python3"
213213
},
@@ -221,7 +221,7 @@
221221
"name": "python",
222222
"nbconvert_exporter": "python",
223223
"pygments_lexer": "ipython3",
224-
"version": "3.10.18"
224+
"version": "3.10.17"
225225
}
226226
},
227227
"nbformat": 4,

src/optimagic/optimization/history.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def params_data(
207207
"""
208208
wide = pd.DataFrame(self.flat_params, columns=self.flat_param_names)
209209
wide["task"] = _task_to_categorical(self.task)
210-
wide["fun"] = self.fun
210+
wide["fun"] = self.fun # type: ignore[assignment]
211211

212212
# If requested, we collapse the batches and only keep the parameters that led to
213213
# the minimal (or maximal) function value in each batch.

src/optimagic/visualization/backends.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -729,15 +729,16 @@ def _grid_line_plot_altair(
729729
plot_title: str | None,
730730
marker_list: list[MarkerData] | None,
731731
make_subplot_kwargs: dict[str, Any] | None = None,
732-
) -> "alt.VConcatChart":
732+
) -> "alt.Chart | alt.HConcatChart | alt.VConcatChart":
733733
"""Create a grid of line plots using Altair.
734734
735735
Args:
736736
...: All other argument descriptions can be found in the docstring of the
737737
`grid_line_plot` function.
738738
739739
Returns:
740-
An Altair VConcatChart object.
740+
An Altair Chart if the grid contains only one subplot, an Altair HConcatChart
741+
if 'n_rows' is 1, otherwise an Altair VConcatChart.
741742
742743
"""
743744
import altair as alt
@@ -774,10 +775,16 @@ def _grid_line_plot_altair(
774775
charts.append(chart_row)
775776

776777
row_selections = [
777-
alt.selection_interval(bind="scales", encodings=["y"]) for _ in range(n_rows)
778+
alt.selection_interval(
779+
bind="scales", encodings=["y"], name=f"share_y_row{row_idx}"
780+
)
781+
for row_idx in range(n_rows)
778782
]
779783
col_selections = [
780-
alt.selection_interval(bind="scales", encodings=["x"]) for _ in range(n_cols)
784+
alt.selection_interval(
785+
bind="scales", encodings=["x"], name=f"share_x_col{col_idx}"
786+
)
787+
for col_idx in range(n_cols)
781788
]
782789

783790
for row_idx, row in enumerate(charts):
@@ -790,13 +797,25 @@ def _grid_line_plot_altair(
790797
params.append(row_selections[row_idx])
791798
else:
792799
# Use independent y-axes for each subplot
793-
params.append(alt.selection_interval(bind="scales", encodings=["y"]))
800+
params.append(
801+
alt.selection_interval(
802+
bind="scales",
803+
encodings=["x", "y"],
804+
name=f"ind_y_row{row_idx}_col{col_idx}",
805+
)
806+
)
794807
if share_x:
795808
# Share x-axis for all subplots in the same column
796809
params.append(col_selections[col_idx])
797810
else:
798811
# Use independent x-axes for each subplot
799-
params.append(alt.selection_interval(bind="scales", encodings=["x"]))
812+
params.append(
813+
alt.selection_interval(
814+
bind="scales",
815+
encodings=["x", "y"],
816+
name=f"ind_x_row{row_idx}_col{col_idx}",
817+
)
818+
)
800819
chart = chart.add_params(*params)
801820

802821
if share_y and col_idx > 0:
@@ -808,7 +827,20 @@ def _grid_line_plot_altair(
808827

809828
charts[row_idx][col_idx] = chart
810829

811-
grid_chart = alt.vconcat(*[alt.hconcat(*row) for row in charts])
830+
row_charts = []
831+
for row in charts:
832+
row_chart: alt.Chart | alt.HConcatChart
833+
if len(row) == 1:
834+
row_chart = row[0]
835+
else:
836+
row_chart = alt.hconcat(*row)
837+
row_charts.append(row_chart)
838+
839+
grid_chart: alt.Chart | alt.HConcatChart | alt.VConcatChart
840+
if len(row_charts) == 1:
841+
grid_chart = row_charts[0]
842+
else:
843+
grid_chart = alt.vconcat(*row_charts)
812844

813845
if plot_title is not None:
814846
grid_chart = grid_chart.properties(title=plot_title)

src/optimagic/visualization/slice_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def _get_plot_data(
292292
metadata.append(meta)
293293

294294
plot_data = pd.DataFrame(metadata)
295-
plot_data["Function Value"] = func_values
295+
plot_data["Function Value"] = func_values # type: ignore[assignment]
296296

297297
return plot_data, internal_params
298298

0 commit comments

Comments
 (0)