diff --git a/functime/plotting.py b/functime/plotting.py index 393290d4..2ccfba6c 100644 --- a/functime/plotting.py +++ b/functime/plotting.py @@ -1,6 +1,5 @@ -from typing import Optional, Union +from typing import Any, Dict, Optional, Tuple, Union -import numpy as np import plotly.express as px import plotly.graph_objects as go import polars as pl @@ -23,6 +22,178 @@ def _remove_legend_duplicates(fig: go.Figure) -> go.Figure: return fig +def _set_subplot_default_kwargs(kwargs: dict, n_rows: int, n_cols: int) -> dict: + """ + Sets or adjusts plot layout properties based on the number of rows and columns in subplots, + ensuring a fixed size for subplots and additional space for titles and other elements. + Default values are applied only if not already specified by the user. + + Parameters + ---------- + kwargs : dict + The original keyword arguments dictionary passed to the plotting function. + n_rows : int + Number of rows in the subplot. + n_cols : int + Number of columns in the subplot. + + Returns + -------- + dict : Updated keyword arguments with adjusted layout properties. + + The function ensures the following: + - A fixed size for each subplot. + - Additional space for plot titles and other elements. + - The overall figure size is dynamically adjusted based on the subplot configuration. + """ + # Fixed size for each subplot + subplot_width = 250 # width for each subplot column + subplot_height = 200 # height for each subplot row + + # Additional space for titles and other elements + additional_space_vertical = 100 # space for titles, labels, etc. + additional_space_horizontal = 100 # additional horizontal space if needed + + # Calculate total width and height + total_width = subplot_width * n_cols + additional_space_horizontal + total_height = subplot_height * n_rows + additional_space_vertical + + # Apply defaults if not already specified by the user + kwargs.setdefault("width", total_width) + kwargs.setdefault("height", total_height) + kwargs.setdefault("template", "plotly_white") + + return kwargs + + +def _calculate_subplot_n_rows(n_series: int, n_cols: int) -> int: + if n_series <= 0: + raise ValueError("n_series must be greater than 0.") + if n_cols <= 0: + raise ValueError("n_cols must be greater than 0.") + n_rows = n_series // n_cols + if n_series % n_cols != 0: + n_rows += 1 + + return n_rows + + +def _get_subplot_grid_position(i: int, n_cols: int) -> Tuple[int, int]: + row = i // n_cols + 1 + col = i % n_cols + 1 + return row, col + + +def _prepare_data_for_subplots( + y: Union[pl.DataFrame, pl.LazyFrame], + n_series: int, + last_n: int, + seed: Union[int, None] = None, +) -> Tuple[pl.Series, int, pl.DataFrame]: + """ + Prepares data for plotting by selecting and sampling entities and getting the recent observations for plotting. + + Parameters + ---------- + y : Union[pl.DataFrame, pl.LazyFrame] + Panel DataFrame of observed values. + n_series : int + Number of entities / time-series to plot. + last_n : int + Plot `last_n` most recent values in `y`. + seed : Union[int, None], optional + Random seed for sampling entities / time-series, by default None. + + Returns + ------- + Tuple[pl.Series, int, pl.DataFrame] + Sampled entities, n_series to plot, and filtered DataFrame. + """ + entity_col = y.columns[0] + + if isinstance(y, pl.DataFrame): + y = y.lazy() + + # Get unique entities + entities = y.select(pl.col(entity_col).unique(maintain_order=True)).collect() + + # If n_series is higher than max unique entities, use max entities + if entities.height < n_series: + n_series = entities.height + + # Sample entities + entities_sample = entities.to_series().sample(n_series, seed=seed) + + # Get most recent observations + y_filtered = ( + y.filter(pl.col(entity_col).is_in(entities_sample)) + .group_by(entity_col) + .tail(last_n) + .collect() + ) + + return entities_sample, n_series, y_filtered + + +def _add_scatter_traces_to_subplots( + fig: go.Figure, + ts: Union[pl.DataFrame, pl.LazyFrame], + ts_pred: Union[pl.DataFrame, pl.LazyFrame, None], + entity_id: Any, + row: int, + col: int, + plot_params: Dict[str, Any], +) -> None: + """ + Adds traces to a specific subplot in the figure for actual and optionally predicted data. + + Parameters + ---------- + fig : go.Figure + The Plotly figure object containing the subplots. + ts : Union[pl.DataFrame, pl.LazyFrame] + The primary time series data to plot. + ts_pred : Union[pl.DataFrame, pl.LazyFrame, None] + The secondary time series data to plot (e.g., predictions or backtests). + If None, only ts is plotted. + entity_id : Any + The identifier for the current entity being plotted. + row : int + The row position in the subplot grid. + col : int + The column position in the subplot grid. + plot_params : Dict[str, Any] + Dictionary containing parameters for plotting, such as color and name. + + Returns + ------- + None + """ + entity_col, time_col, target_col = ts.columns[:3] + + ts_trace = go.Scatter( + x=ts.filter(pl.col(entity_col) == entity_id).get_column(time_col), + y=ts.filter(pl.col(entity_col) == entity_id).get_column(target_col), + name=plot_params.get("ts_name", "Actual"), + legendgroup=plot_params.get("ts_name", "Actual"), + line=dict(color=plot_params.get("ts_color", "blue")), + ) + + fig.add_trace(ts_trace, row=row, col=col) + + if ts_pred is not None: + if isinstance(ts_pred, pl.LazyFrame): + ts_pred = ts_pred.collect() + ts_pred_trace = go.Scatter( + x=ts_pred.filter(pl.col(entity_col) == entity_id).get_column(time_col), + y=ts_pred.filter(pl.col(entity_col) == entity_id).get_column(target_col), + name=plot_params.get("ts_pred_name", "Forecast"), + legendgroup=plot_params.get("ts_pred_name", "Forecast"), + line=dict(color=plot_params.get("ts_pred_color", "red"), dash="dash"), + ) + fig.add_trace(ts_pred_trace, row=row, col=col) + + def plot_entities( y: Union[pl.DataFrame, pl.LazyFrame], **kwargs, @@ -51,24 +222,28 @@ def plot_entities( title = kwargs.pop("title", "Entities counts") template = kwargs.pop("template", "plotly_white") - return px.bar( - data_frame=entity_counts, - x="count", - y=entity_col, - orientation="h", - ).update_layout( + fig = go.Figure( + go.Bar( + x=entity_counts.get_column("count"), + y=entity_counts.get_column(entity_col), + orientation="h", + ) + ) + fig.update_layout( height=height, title=title, template=template, **kwargs, ) + return fig + def plot_panel( y: Union[pl.DataFrame, pl.LazyFrame], *, n_series: int = 10, - seed: int | None = None, + seed: Union[int, None] = None, n_cols: int = 2, last_n: int = DEFAULT_LAST_N, **kwargs, @@ -83,7 +258,7 @@ def plot_panel( n_series : int Number of entities / time-series to plot. Defaults to 10. - seed : int | None + seed : Union[int, None] Random seed for sampling entities / time-series. Defaults to None. n_cols : int @@ -98,48 +273,48 @@ def plot_panel( figure : plotly.graph_objects.Figure Plotly subplots. """ - entity_col, time_col, target_col = y.columns[:3] - - if isinstance(y, pl.DataFrame): - y = y.lazy() + entity_col = y.columns[0] - entities = y.select(pl.col(entity_col).unique(maintain_order=True)).collect() + # Get sampled entities, check validity of n_series + # and filter the y df to contain last_n values + entities_sample, n_series, y_filtered = _prepare_data_for_subplots( + y=y, + n_series=n_series, + last_n=last_n, + seed=seed, + ) - entities_sample = entities.to_series().sample(n_series, seed=seed) + # Define grid and make subplots + n_rows = _calculate_subplot_n_rows(n_series=n_series, n_cols=n_cols) + fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=entities_sample) - # Get most recent observations - y = ( - y.filter(pl.col(entity_col).is_in(entities_sample)) - .group_by(entity_col) - .tail(last_n) - .collect() - ) + # Define default names and colors to be used + plot_params = { + "ts_name": "Time-series", + "ts_color": COLOR_PALETTE["actual"], + } - # Organize subplots - n_rows = n_series // n_cols - row_idx = np.repeat(range(n_rows), n_cols) - fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=entities) - - for i, entity_id in enumerate(entities): - ts = y.filter(pl.col(entity_col) == entity_id) - row = row_idx[i] + 1 - col = i % n_cols + 1 - # Plot actual - fig.add_trace( - go.Scatter( - x=ts.get_column(time_col), - y=ts.get_column(target_col), - name="Time-series", - legendgroup="Time-series", - line=dict(color=COLOR_PALETTE["forecast"]), - ), + # Loop and plot each sampled entity + for i, entity_id in enumerate(entities_sample): + ts = y_filtered.filter(pl.col(entity_col) == entity_id) + # Get the subplot position for the ts + row, col = _get_subplot_grid_position(i=i, n_cols=n_cols) + # Plot trace(s) for the timeseries + _add_scatter_traces_to_subplots( + fig=fig, + ts=ts, + ts_pred=None, + entity_id=entity_id, row=row, col=col, + plot_params=plot_params, ) - template = kwargs.pop("template", "plotly_white") + # Set default kwargs for plotting if user did not provide these + kwargs = _set_subplot_default_kwargs(kwargs=kwargs, n_rows=n_rows, n_cols=n_cols) - fig.update_layout(template=template, **kwargs) + # Tidy up the plot + fig.update_layout(**kwargs) fig = _remove_legend_duplicates(fig) return fig @@ -149,7 +324,7 @@ def plot_forecasts( y_pred: pl.DataFrame, *, n_series: int = 10, - seed: int | None = None, + seed: Union[int, None] = None, n_cols: int = 2, last_n: int = DEFAULT_LAST_N, **kwargs, @@ -166,7 +341,7 @@ def plot_forecasts( n_series : int Number of entities / time-series to plot. Defaults to 10. - seed : int | None + seed : Union[int, None] Random seed for sampling entities / time-series. Defaults to None. n_cols : int @@ -181,62 +356,49 @@ def plot_forecasts( figure : plotly.graph_objects.Figure Plotly subplots. """ - entity_col, time_col, target_col = y_true.columns[:3] - - if isinstance(y_true, pl.DataFrame): - y_true = y_true.lazy() - - # Get most recent observations - entities = y_true.select(pl.col(entity_col).unique(maintain_order=True)).collect() - - entities_sample = entities.to_series().sample(n_series, seed=seed) - - # Get most recent observations - y = ( - y_true.filter(pl.col(entity_col).is_in(entities_sample)) - .group_by(entity_col) - .tail(last_n) - .collect() + entity_col = y_true.columns[0] + + # Get sampled entities, check validity of n_series + # and filter the y df to contain last_n values + entities_sample, n_series, y_filtered = _prepare_data_for_subplots( + y=y_true, + n_series=n_series, + last_n=last_n, + seed=seed, ) - # Organize subplots - n_rows = n_series // n_cols - row_idx = np.repeat(range(n_rows), n_cols) - fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=entities) + # Define grid and make subplots + n_rows = _calculate_subplot_n_rows(n_series=n_series, n_cols=n_cols) + fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=entities_sample) + + plot_params = { + "ts_name": "Actual", + "ts_color": COLOR_PALETTE["actual"], + "ts_pred_name": "Forecast", + "ts_pred_color": COLOR_PALETTE["forecast"], + } - for i, entity_id in enumerate(entities): - ts = y.filter(pl.col(entity_col) == entity_id) + for i, entity_id in enumerate(entities_sample): + ts = y_filtered.filter(pl.col(entity_col) == entity_id) ts_pred = y_pred.filter(pl.col(entity_col) == entity_id) - row = row_idx[i] + 1 - col = i % n_cols + 1 - # Plot actual - fig.add_trace( - go.Scatter( - x=ts.get_column(time_col), - y=ts.get_column(target_col), - name="Actual", - legendgroup="Actual", - line=dict(color=COLOR_PALETTE["actual"]), - ), - row=row, - col=col, - ) - # Plot forecast - fig.add_trace( - go.Scatter( - x=ts_pred.get_column(time_col), - y=ts_pred.get_column(target_col), - name="Forecast", - legendgroup="Forecast", - line=dict(color=COLOR_PALETTE["forecast"], dash="dash"), - ), + # Get the subplot position for the ts + row, col = _get_subplot_grid_position(i=i, n_cols=n_cols) + # Plot trace(s) for the timeseries + _add_scatter_traces_to_subplots( + fig=fig, + ts=ts, + ts_pred=ts_pred, + entity_id=entity_id, row=row, col=col, + plot_params=plot_params, ) - template = kwargs.pop("template", "plotly_white") + # Set default kwargs for plotting if user did not provide these + kwargs = _set_subplot_default_kwargs(kwargs=kwargs, n_rows=n_rows, n_cols=n_cols) - fig.update_layout(template=template, **kwargs) + # Tidy up the plot + fig.update_layout(**kwargs) fig = _remove_legend_duplicates(fig) return fig @@ -246,7 +408,7 @@ def plot_backtests( y_preds: pl.DataFrame, *, n_series: int = 10, - seed: int | None = None, + seed: Union[int, None] = None, n_cols: int = 2, last_n: int = DEFAULT_LAST_N, **kwargs, @@ -263,7 +425,7 @@ def plot_backtests( n_series : int Number of entities / time-series to plot. Defaults to 10. - seed : int | None + seed : Union[int, None] Random seed for sampling entities / time-series. Defaults to None. n_cols : int @@ -278,62 +440,49 @@ def plot_backtests( figure : plotly.graph_objects.Figure Plotly subplots. """ - entity_col, time_col, target_col = y_true.columns[:3] - - if isinstance(y_true, pl.DataFrame): - y_true = y_true.lazy() - - # Get most recent observations - entities = y_true.select(pl.col(entity_col).unique(maintain_order=True)).collect() - - entities_sample = entities.to_series().sample(n_series, seed=seed) - - # Get most recent observations - y = ( - y_true.filter(pl.col(entity_col).is_in(entities_sample)) - .group_by(entity_col) - .tail(last_n) - .collect() + entity_col = y_true.columns[0] + + # Get sampled entities, check validity of n_series + # and filter the y df to contain last_n values + entities_sample, n_series, y_filtered = _prepare_data_for_subplots( + y=y_true, + n_series=n_series, + last_n=last_n, + seed=seed, ) - # Organize subplots - n_rows = n_series // n_cols - row_idx = np.repeat(range(n_rows), n_cols) - fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=entities) + # Define grid and make subplots + n_rows = _calculate_subplot_n_rows(n_series=n_series, n_cols=n_cols) + fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=entities_sample) - for i, entity_id in enumerate(entities): - ts = y.filter(pl.col(entity_col) == entity_id) + plot_params = { + "ts_name": "Actual", + "ts_color": COLOR_PALETTE["actual"], + "ts_pred_name": "Backtest", + "ts_pred_color": COLOR_PALETTE["backtest"], + } + + for i, entity_id in enumerate(entities_sample): + ts = y_filtered.filter(pl.col(entity_col) == entity_id) ts_pred = y_preds.filter(pl.col(entity_col) == entity_id) - row = row_idx[i] + 1 - col = i % n_cols + 1 - # Plot actual - fig.add_trace( - go.Scatter( - x=ts.get_column(time_col), - y=ts.get_column(target_col), - name="Actual", - legendgroup="Actual", - line=dict(color=COLOR_PALETTE["actual"]), - ), - row=row, - col=col, - ) - # Plot forecast - fig.add_trace( - go.Scatter( - x=ts_pred.get_column(time_col), - y=ts_pred.get_column(target_col), - name="Backtest", - legendgroup="Backtest", - line=dict(color=COLOR_PALETTE["backtest"], dash="dash"), - ), + # Get the subplot position for the ts + row, col = _get_subplot_grid_position(i=i, n_cols=n_cols) + # Plot trace(s) for the timeseries + _add_scatter_traces_to_subplots( + fig=fig, + ts=ts, + ts_pred=ts_pred, + entity_id=entity_id, row=row, col=col, + plot_params=plot_params, ) - template = kwargs.pop("template", "plotly_white") # noqa: F841 + # Set default kwargs for plotting if user did not provide these + kwargs = _set_subplot_default_kwargs(kwargs=kwargs, n_rows=n_rows, n_cols=n_cols) - fig.update_layout(template=template, **kwargs) + # Tidy up the plot + fig.update_layout(**kwargs) fig = _remove_legend_duplicates(fig) return fig @@ -368,7 +517,7 @@ def plot_residuals( y_resids = y_resids.with_columns(pl.col(target_col).alias("Residuals")).collect() fig = px.histogram( - y_resids, + y_resids.to_pandas(), x="Residuals", y="Residuals", color=entity_col, @@ -424,7 +573,10 @@ def plot_comet( mean_score = scores.get_column(scores.columns[-1]).mean() mean_cv = cvs.get_column(cvs.columns[-1]).mean() fig = px.scatter( - comet, x=cvs.columns[-1], y=scores.columns[-1], hover_data=entity_col + comet.to_pandas(), + x=cvs.columns[-1], + y=scores.columns[-1], + hover_data=entity_col, ) fig.add_hline(y=mean_score) fig.add_vline(x=mean_cv) @@ -473,7 +625,7 @@ def plot_fva( how="left", on=scores.columns[0], ) - fig = px.scatter(uplift, x=x_title, y=y_title, hover_data=entity_col) + fig = px.scatter(uplift.to_pandas(), x=x_title, y=y_title, hover_data=entity_col) deg45_line = { "type": "line", "yref": "paper", diff --git a/tests/test_plotting.py b/tests/test_plotting.py new file mode 100644 index 00000000..17b82476 --- /dev/null +++ b/tests/test_plotting.py @@ -0,0 +1,99 @@ +import polars as pl +import pytest + +from functime import plotting + + +def test_set_subplot_default_kwargs_no_existing_kwargs(): + kwargs = {} + updated_kwargs = plotting._set_subplot_default_kwargs(kwargs, 2, 3) + + assert updated_kwargs["width"] == 250 * 3 + 100 # default width * cols + space + assert updated_kwargs["height"] == 200 * 2 + 100 # default height * rows + space + assert updated_kwargs["template"] == "plotly_white" + + +def test_set_subplot_default_kwargs_with_one_defined_kwarg(): + kwargs = {"width": 800, "some_other_kwarg": "value"} + updated_kwargs = plotting._set_subplot_default_kwargs(kwargs, 2, 3) + + assert updated_kwargs["width"] == 800 # Should remain unchanged + assert updated_kwargs["height"] == 200 * 2 + 100 # default height * rows + space + assert updated_kwargs["some_other_kwarg"] == "value" + + +@pytest.mark.parametrize( + "n_series, n_cols, expected_rows", + [ + (10, 2, 5), # 10 series in 2 columns > 5 rows + (10, 1, 10), # All series in one column > 10 rows + (10, 3, 4), # Series not exactly divisible by columns + (10, 10, 1), # Each series in its own column + (10, 15, 1), # More columns than series + ], +) +def test_calculate_subplot_n_rows(n_series, n_cols, expected_rows): + assert plotting._calculate_subplot_n_rows(n_series, n_cols) == expected_rows + + +@pytest.mark.parametrize( + "n_series, n_cols", + [ + (0, 2), # No series + (10, 0), # Zero columns + (-1, 2), # Negative series + (10, -2), # Negative columns + ], +) +def test_calculate_subplot_n_rows_errors(n_series, n_cols): + with pytest.raises(ValueError): + plotting._calculate_subplot_n_rows(n_series, n_cols) + + +def create_mock_dataframe(): + # Create a mock DataFrame for testing + data = { + "entity": ["A", "A", "B", "B", "C", "C"], + "time": [1, 2, 1, 2, 1, 2], + "value": [10, 20, 30, 40, 50, 60], + } + return pl.DataFrame(data) + + +@pytest.mark.parametrize( + "n_series, last_n, expected_entities", + [ + (2, 1, {"A", "B"}), # Test with 2 series, last 1 record + (3, 2, {"A", "B", "C"}), # Test with all series, last 2 records + (4, 2, {"A", "B", "C"}), # More series than available + ], +) +def test_prepare_data_for_subplots(n_series, last_n, expected_entities): + df = create_mock_dataframe() + entities_sample, _, y_filtered = plotting._prepare_data_for_subplots( + df, n_series, last_n, seed=1 + ) + + # Check if the correct entities are sampled + assert set(entities_sample) == expected_entities + + # Check if the data is correctly filtered + for entity in entities_sample: + assert y_filtered.filter(pl.col("entity") == entity).height <= last_n + + +@pytest.mark.parametrize( + "i, n_cols, expected_row_col", + [ + (0, 3, (1, 1)), # First series & 3 cols pos = 1,1 + (1, 3, (1, 2)), # Second series & 3 cols pos = 1,2 + (2, 3, (1, 3)), # Third series & 3 cols pos = 1,3 + (3, 3, (2, 1)), # Fourth series, start of second row + (26, 3, (9, 3)), # 27th series in a 3-column layout + (27, 3, (10, 1)), # 28th series, starts a new row + (160, 7, (23, 7)), # 161st series, ends in the last col + (161, 7, (24, 1)), # 162nd series, starts a new row + ], +) +def test_get_subplot_grid_position(i, n_cols, expected_row_col): + assert plotting._get_subplot_grid_position(i, n_cols) == expected_row_col