Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions Python-packages/covidcast-py/covidcast/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def plot(data: pd.DataFrame,
time_value: date = None,
plot_type: str = "choropleth",
combine_megacounties: bool = True,
**kwargs: Any) -> figure.Figure:
ax: axes.Axes = None,
**kwargs: Any) -> axes.Axes:
"""Given the output data frame of :py:func:`covidcast.signal`, plot a choropleth or bubble map.

Projections used for plotting:
Expand Down Expand Up @@ -71,6 +72,9 @@ def plot(data: pd.DataFrame,
bubble but have the region displayed in white, and values above the mean + 3 std dev are binned
into the highest bubble. Bubbles are scaled by area.

A Matplotlib Axes object can be provided to plot the maps onto an existing figure. Otherwise,
a new Axes object will be created and returned.

:param data: Data frame of signal values, as returned from :py:func:`covidcast.signal`.
:param time_value: If multiple days of data are present in ``data``, map only values from this
day. Defaults to plotting the most recent day of data in ``data``.
Expand All @@ -79,7 +83,8 @@ def plot(data: pd.DataFrame,
Defaults to `True`.
:param kwargs: Optional keyword arguments passed to ``GeoDataFrame.plot()``.
:param plot_type: Type of plot to create. Either choropleth (default) or bubble map.
:return: Matplotlib figure object.
:param ax: Optional matplotlib axis to plot on.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the docstring should say what happens if ax is not provided (axes are made, states are plotted). Maybe that goes in the text above, such as in a paragraph after the one-line summary explaining that the maps are plotted on US states by default.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I had a bug here: the background states are always plotted, and I had left that statement out. I've also added a sentence describing the return behavior.

:return: Matplotlib axes object.

"""
if plot_type not in {"choropleth", "bubble"}:
Expand All @@ -92,26 +97,31 @@ def plot(data: pd.DataFrame,
kwargs["vmax"] = kwargs.get("vmax", meta["mean_value"] + 3 * meta["stdev_value"])
kwargs["figsize"] = kwargs.get("figsize", (12.8, 9.6))

fig, ax = _plot_background_states(kwargs["figsize"])
ax = _plot_background_states(kwargs["figsize"]) if ax is None \
else _plot_background_states(ax=ax)
ax.axis("off")
ax.set_title(f"{data_source}: {signal}, {day_to_plot.strftime('%Y-%m-%d')}")
if plot_type == "choropleth":
_plot_choro(ax, day_data, combine_megacounties, **kwargs)
else:
_plot_bubble(ax, day_data, geo_type, **kwargs)
return fig
return ax


def plot_choropleth(data: pd.DataFrame,
time_value: date = None,
combine_megacounties: bool = True,
**kwargs: Any) -> figure.Figure:
**kwargs: Any) -> axes.Axes:
"""Plot choropleths for a signal. This method is deprecated and has been generalized to plot().
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Want to use the deprecated directive here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added. This also made me realize our changelog is a bit incorrect; I will make another PR to fix that.


.. deprecated:: 0.1.1
Use ``plot()`` instead.

:param data: Data frame of signal values, as returned from :py:func:`covidcast.signal`.
:param time_value: If multiple days of data are present in ``data``, map only values from this
day. Defaults to plotting the most recent day of data in ``data``.
:param kwargs: Optional keyword arguments passed to ``GeoDataFrame.plot()``.
:return: Matplotlib figure object.
:return: Matplotlib axes object.
"""
warnings.warn("Function `plot_choropleth` is deprecated. Use `plot()` instead.")
return plot(data, time_value, "choropleth", combine_megacounties, **kwargs)
Expand Down Expand Up @@ -286,21 +296,22 @@ def _plot_bubble(ax: axes.Axes, data: gpd.GeoDataFrame, geo_type: str, **kwargs:
ax.legend(frameon=False, ncol=8, loc="lower center", bbox_to_anchor=(0.5, -0.1))


def _plot_background_states(figsize: tuple) -> tuple:
def _plot_background_states(figsize: tuple = (12.8, 9.6), ax: axes.Axes = None) -> axes.Axes:
"""Plot US states in light grey as the background for other plots.

:param figsize: Dimensions of plot.
:return: Matplotlib figure and axes.
:param figsize: Dimensions of plot. Ignored if ax is provided.
:param ax: Optional matplotlib axis to plot on.
:return: Matplotlib axes.
"""
fig, ax = plt.subplots(1, figsize=figsize)
ax.axis("off")
if ax is None:
fig, ax = plt.subplots(1, figsize=figsize)
state_shapefile_path = pkg_resources.resource_filename(__name__, SHAPEFILE_PATHS["state"])
state = gpd.read_file(state_shapefile_path)
for state in _project_and_transform(state, "STATEFP"):
state.plot(color="0.9", ax=ax, edgecolor="0.8", linewidth=0.5)
ax.set_xlim(plt.xlim())
ax.set_ylim(plt.ylim())
return fig, ax
ax.set_xlim(ax.get_xlim())
ax.set_ylim(ax.get_ylim())
return ax


def _project_and_transform(data: gpd.GeoDataFrame,
Expand Down
30 changes: 17 additions & 13 deletions Python-packages/covidcast-py/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import patch

import matplotlib
from matplotlib import pyplot as plt
import platform
import geopandas as gpd
import numpy as np
Expand Down Expand Up @@ -49,21 +50,22 @@ def test_plot(mock_metadata):
test_county["value"] = test_county.value.astype("float")

# w/o megacounties
no_mega_fig1 = plotting.plot(test_county,
time_value=date(2020, 8, 4),
combine_megacounties=False)
plotting.plot(test_county, time_value=date(2020, 8, 4), combine_megacounties=False)
no_mega_fig1 = plt.gcf()
# give margin of +-2 for floating point errors and weird variations (1 isn't consistent)
assert np.allclose(_convert_to_array(no_mega_fig1), expected["no_mega_1"], atol=2, rtol=0)

no_mega_fig2 = plotting.plot_choropleth(test_county,
cmap="viridis",
figsize=(5, 5),
edgecolor="0.8",
combine_megacounties=False)
plotting.plot_choropleth(test_county,
cmap="viridis",
figsize=(5, 5),
edgecolor="0.8",
combine_megacounties=False)
no_mega_fig2 = plt.gcf()
assert np.allclose(_convert_to_array(no_mega_fig2), expected["no_mega_2"], atol=2, rtol=0)

# w/ megacounties
mega_fig = plotting.plot_choropleth(test_county, time_value=date(2020, 8, 4))
plotting.plot_choropleth(test_county, time_value=date(2020, 8, 4))
mega_fig = plt.gcf()
# give margin of +-2 for floating point errors and weird variations (1 isn't consistent)
assert np.allclose(_convert_to_array(mega_fig), expected["mega"], atol=2, rtol=0)

Expand All @@ -72,20 +74,22 @@ def test_plot(mock_metadata):
os.path.join(CURRENT_PATH, "reference_data/test_input_state_signal.csv"), dtype=str)
test_state["time_value"] = test_state.time_value.astype("datetime64[D]")
test_state["value"] = test_state.value.astype("float")
state_fig = plotting.plot(test_state)
plotting.plot(test_state)
state_fig = plt.gcf()
assert np.allclose(_convert_to_array(state_fig), expected["state"], atol=2, rtol=0)

# test MSA
test_msa = pd.read_csv(
os.path.join(CURRENT_PATH, "reference_data/test_input_msa_signal.csv"), dtype=str)
test_msa["time_value"] = test_msa.time_value.astype("datetime64[D]")
test_msa["value"] = test_msa.value.astype("float")
msa_fig = plotting.plot(test_msa)
plotting.plot(test_msa)
msa_fig = plt.gcf()
assert np.allclose(_convert_to_array(msa_fig), expected["msa"], atol=2, rtol=0)

# test bubble
msa_bubble_fig = plotting.plot(test_msa, plot_type="bubble")
from matplotlib import pyplot as plt
plotting.plot(test_msa, plot_type="bubble")
msa_bubble_fig = plt.gcf()
assert np.allclose(_convert_to_array(msa_bubble_fig), expected["msa_bubble"], atol=2, rtol=0)


Expand Down