diff --git a/Python-packages/covidcast-py/covidcast/plotting.py b/Python-packages/covidcast-py/covidcast/plotting.py index 2eb0e78b..859630bb 100644 --- a/Python-packages/covidcast-py/covidcast/plotting.py +++ b/Python-packages/covidcast-py/covidcast/plotting.py @@ -43,7 +43,8 @@ def plot(data: pd.DataFrame, time_value: date = None, plot_type: str = "choropleth", combine_megacounties: bool = True, - **kwargs: Any) -> figure.Figure: + ax: axes.Axes = None, + **kwargs: Any) -> axes.Axes: """Given the output data frame of :py:func:`covidcast.signal`, plot a choropleth or bubble map. Projections used for plotting: @@ -71,6 +72,9 @@ def plot(data: pd.DataFrame, bubble but have the region displayed in white, and values above the mean + 3 std dev are binned into the highest bubble. Bubbles are scaled by area. + A Matplotlib Axes object can be provided to plot the maps onto an existing figure. Otherwise, + a new Axes object will be created and returned. + :param data: Data frame of signal values, as returned from :py:func:`covidcast.signal`. :param time_value: If multiple days of data are present in ``data``, map only values from this day. Defaults to plotting the most recent day of data in ``data``. @@ -79,7 +83,8 @@ def plot(data: pd.DataFrame, Defaults to `True`. :param kwargs: Optional keyword arguments passed to ``GeoDataFrame.plot()``. :param plot_type: Type of plot to create. Either choropleth (default) or bubble map. - :return: Matplotlib figure object. + :param ax: Optional matplotlib axis to plot on. + :return: Matplotlib axes object. """ if plot_type not in {"choropleth", "bubble"}: @@ -92,26 +97,31 @@ def plot(data: pd.DataFrame, kwargs["vmax"] = kwargs.get("vmax", meta["mean_value"] + 3 * meta["stdev_value"]) kwargs["figsize"] = kwargs.get("figsize", (12.8, 9.6)) - fig, ax = _plot_background_states(kwargs["figsize"]) + ax = _plot_background_states(kwargs["figsize"]) if ax is None \ + else _plot_background_states(ax=ax) + ax.axis("off") ax.set_title(f"{data_source}: {signal}, {day_to_plot.strftime('%Y-%m-%d')}") if plot_type == "choropleth": _plot_choro(ax, day_data, combine_megacounties, **kwargs) else: _plot_bubble(ax, day_data, geo_type, **kwargs) - return fig + return ax def plot_choropleth(data: pd.DataFrame, time_value: date = None, combine_megacounties: bool = True, - **kwargs: Any) -> figure.Figure: + **kwargs: Any) -> axes.Axes: """Plot choropleths for a signal. This method is deprecated and has been generalized to plot(). + .. deprecated:: 0.1.1 + Use ``plot()`` instead. + :param data: Data frame of signal values, as returned from :py:func:`covidcast.signal`. :param time_value: If multiple days of data are present in ``data``, map only values from this day. Defaults to plotting the most recent day of data in ``data``. :param kwargs: Optional keyword arguments passed to ``GeoDataFrame.plot()``. - :return: Matplotlib figure object. + :return: Matplotlib axes object. """ warnings.warn("Function `plot_choropleth` is deprecated. Use `plot()` instead.") return plot(data, time_value, "choropleth", combine_megacounties, **kwargs) @@ -286,21 +296,22 @@ def _plot_bubble(ax: axes.Axes, data: gpd.GeoDataFrame, geo_type: str, **kwargs: ax.legend(frameon=False, ncol=8, loc="lower center", bbox_to_anchor=(0.5, -0.1)) -def _plot_background_states(figsize: tuple) -> tuple: +def _plot_background_states(figsize: tuple = (12.8, 9.6), ax: axes.Axes = None) -> axes.Axes: """Plot US states in light grey as the background for other plots. - :param figsize: Dimensions of plot. - :return: Matplotlib figure and axes. + :param figsize: Dimensions of plot. Ignored if ax is provided. + :param ax: Optional matplotlib axis to plot on. + :return: Matplotlib axes. """ - fig, ax = plt.subplots(1, figsize=figsize) - ax.axis("off") + if ax is None: + fig, ax = plt.subplots(1, figsize=figsize) state_shapefile_path = pkg_resources.resource_filename(__name__, SHAPEFILE_PATHS["state"]) state = gpd.read_file(state_shapefile_path) for state in _project_and_transform(state, "STATEFP"): state.plot(color="0.9", ax=ax, edgecolor="0.8", linewidth=0.5) - ax.set_xlim(plt.xlim()) - ax.set_ylim(plt.ylim()) - return fig, ax + ax.set_xlim(ax.get_xlim()) + ax.set_ylim(ax.get_ylim()) + return ax def _project_and_transform(data: gpd.GeoDataFrame, diff --git a/Python-packages/covidcast-py/tests/test_plotting.py b/Python-packages/covidcast-py/tests/test_plotting.py index 8d458cfb..01ac6e71 100644 --- a/Python-packages/covidcast-py/tests/test_plotting.py +++ b/Python-packages/covidcast-py/tests/test_plotting.py @@ -3,6 +3,7 @@ from unittest.mock import patch import matplotlib +from matplotlib import pyplot as plt import platform import geopandas as gpd import numpy as np @@ -49,21 +50,22 @@ def test_plot(mock_metadata): test_county["value"] = test_county.value.astype("float") # w/o megacounties - no_mega_fig1 = plotting.plot(test_county, - time_value=date(2020, 8, 4), - combine_megacounties=False) + plotting.plot(test_county, time_value=date(2020, 8, 4), combine_megacounties=False) + no_mega_fig1 = plt.gcf() # give margin of +-2 for floating point errors and weird variations (1 isn't consistent) assert np.allclose(_convert_to_array(no_mega_fig1), expected["no_mega_1"], atol=2, rtol=0) - no_mega_fig2 = plotting.plot_choropleth(test_county, - cmap="viridis", - figsize=(5, 5), - edgecolor="0.8", - combine_megacounties=False) + plotting.plot_choropleth(test_county, + cmap="viridis", + figsize=(5, 5), + edgecolor="0.8", + combine_megacounties=False) + no_mega_fig2 = plt.gcf() assert np.allclose(_convert_to_array(no_mega_fig2), expected["no_mega_2"], atol=2, rtol=0) # w/ megacounties - mega_fig = plotting.plot_choropleth(test_county, time_value=date(2020, 8, 4)) + plotting.plot_choropleth(test_county, time_value=date(2020, 8, 4)) + mega_fig = plt.gcf() # give margin of +-2 for floating point errors and weird variations (1 isn't consistent) assert np.allclose(_convert_to_array(mega_fig), expected["mega"], atol=2, rtol=0) @@ -72,7 +74,8 @@ def test_plot(mock_metadata): os.path.join(CURRENT_PATH, "reference_data/test_input_state_signal.csv"), dtype=str) test_state["time_value"] = test_state.time_value.astype("datetime64[D]") test_state["value"] = test_state.value.astype("float") - state_fig = plotting.plot(test_state) + plotting.plot(test_state) + state_fig = plt.gcf() assert np.allclose(_convert_to_array(state_fig), expected["state"], atol=2, rtol=0) # test MSA @@ -80,12 +83,13 @@ def test_plot(mock_metadata): os.path.join(CURRENT_PATH, "reference_data/test_input_msa_signal.csv"), dtype=str) test_msa["time_value"] = test_msa.time_value.astype("datetime64[D]") test_msa["value"] = test_msa.value.astype("float") - msa_fig = plotting.plot(test_msa) + plotting.plot(test_msa) + msa_fig = plt.gcf() assert np.allclose(_convert_to_array(msa_fig), expected["msa"], atol=2, rtol=0) # test bubble - msa_bubble_fig = plotting.plot(test_msa, plot_type="bubble") - from matplotlib import pyplot as plt + plotting.plot(test_msa, plot_type="bubble") + msa_bubble_fig = plt.gcf() assert np.allclose(_convert_to_array(msa_bubble_fig), expected["msa_bubble"], atol=2, rtol=0)