In [2]:
%%time
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
import us, os, imageio
from datetime import date, datetime
from matplotlib.animation import FuncAnimation
from natsort import natsorted

# Whole days from an ISO "YYYY-MM-DD" date string up to end_date (-1 if missing or in the future)
def days_since(s_date, end_date):
    """Return the whole number of days from ``s_date`` to ``end_date``.

    Args:
        s_date: ISO-format date string ("YYYY-MM-DD"). Any non-string value
            (e.g. a NaN left by an unmatched join) is treated as missing.
        end_date: ``datetime.date`` to measure up to.

    Returns:
        The day count, or -1 when ``s_date`` is missing or falls after
        ``end_date``.
    """
    # Counties without case data carry a non-string placeholder.
    if isinstance(s_date, str):
        elapsed = (end_date - date.fromisoformat(s_date)).days
        # A first case later than end_date has not happened yet.
        if elapsed >= 0:
            return elapsed
    return -1

# Load shapefile data
def load_shapefile(filepath):
    """Load county boundaries for the contiguous United States.

    Args:
        filepath: Path to the county shapefile (``.shp``) to read.

    Returns:
        A GeoDataFrame with Alaska, Puerto Rico and Hawai'i rows removed,
        reprojected to EPSG:2163 (an equal-area CRS) so county areas are
        comparable on the choropleth.
    """
    # Read the file the caller asked for; this previously ignored
    # ``filepath`` and always opened a hard-coded path.
    map_df = gpd.read_file(filepath)

    # Filter out Alaska, Puerto Rico, and Hawai'i so the map frame stays compact.
    excluded_states = ["AK", "PR", "HI"]
    map_df = map_df[~map_df["STATE"].isin(excluded_states)]

    # Change CRS to equal area
    return map_df.to_crs("EPSG:2163")

# Load NY Times COVID-19 data
def load_cases_data(filepath):
    """Load the NY Times per-county COVID-19 CSV, reduced to one row per county.

    Args:
        filepath: Path to the NY Times ``us-counties.csv`` file.

    Returns:
        A DataFrame with one row per (county, state). Each row keeps the
        earliest ``date`` on which that county appears and the original
        ``state`` name. The county column is renamed to ``COUNTYNAME``, and a
        ``STATE`` column holds the two-letter abbreviation; both match
        ``map_df``. ``STATE`` is None for state names ``us.states.lookup``
        does not recognise.
    """
    cases_df = pd.read_csv(filepath)

    # Keep only the earliest row per county. Sorting first (stable, so ties
    # keep file order) means this no longer relies on the CSV already being
    # in date order. ISO date strings sort chronologically.
    cases_df = cases_df.sort_values("date", kind="stable")
    cases_df = cases_df.drop_duplicates(subset=["county", "state"], keep="first")

    # Filter out unneeded columns
    cases_df = cases_df.drop(columns=["fips", "cases", "deaths"])

    # Rename "county" label to match map_df
    cases_df = cases_df.rename(columns={"county": "COUNTYNAME"})

    # Resolve each distinct state name once instead of once per row.
    # lookup() may return None for unrecognised names; map those to None
    # rather than crashing with AttributeError on ``.abbr``.
    abbr_by_name = {}
    for name in cases_df["state"].unique():
        found = us.states.lookup(name)
        abbr_by_name[name] = found.abbr if found is not None else None
    cases_df["STATE"] = cases_df["state"].map(abbr_by_name)

    return cases_df

def make_merged_df(map_df, cases_df):
    """Left-join per-county case data onto the county map.

    Every row of ``map_df`` is kept. Columns of ``cases_df`` other than the
    join keys are attached where (COUNTYNAME, STATE) matches, and are NaN
    where it does not.
    """
    join_keys = ["COUNTYNAME", "STATE"]
    keyed_cases = cases_df.set_index(join_keys)
    return map_df.join(keyed_cases, on=join_keys)

def make_plot(merged_df):
    """Render one choropleth frame per day and assemble them into a GIF.

    Each frame colours every county by the number of days since its first
    recorded case as of that frame's date. The value is -1 for counties with
    no case yet or no data. Frames are saved as JPEGs in ``export/`` and
    combined into ``covid-spread.gif`` in the current working directory.

    Args:
        merged_df: County GeoDataFrame with a ``date`` column holding each
            county's first-case ISO date string (non-strings mean no data),
            as produced by ``make_merged_df``.

    Side effects:
        Overwrites ``merged_df["days_since_p0"]``; after the call it holds the
        values of the last frame drawn. Writes one JPEG per day and the GIF.

    Raises:
        ValueError: if ``merged_df`` has no first-case date strings at all.
    """
    # Set output dir path, creating it so savefig does not fail on a fresh checkout.
    output_dir = "export/"
    os.makedirs(output_dir, exist_ok=True)

    # Set colormap (https://matplotlib.org/3.1.1/gallery/color/colormap_reference.html)
    cmap = "gist_heat_r"

    # Capture "today" once so the colour scale and the frame range agree even
    # if the run crosses midnight.
    today = date.today()

    # The colour scale spans 0 .. the largest day count as of today, fixed for every frame.
    d_min, d_max = 0, _days_since_column(merged_df, today).max()
    norm = plt.Normalize(vmin=d_min, vmax=d_max)

    # Frames run from the earliest first-case date through today.
    first_dates = [d for d in merged_df["date"] if isinstance(d, str)]
    if not first_dates:
        raise ValueError("merged_df has no first-case dates to plot")
    start_date = date.fromisoformat(min(first_dates))
    date_range = [ts.date() for ts in pd.date_range(start=start_date.isoformat(), end=today.isoformat())]

    plot_files = [_save_frame(merged_df, end_date, norm, cmap, output_dir) for end_date in date_range]
    if not plot_files:
        return

    # Sort image files by date (file names start with the ISO date)
    plot_files = natsorted(plot_files)

    _write_gif(plot_files, "covid-spread.gif")


def _days_since_column(merged_df, end_date):
    """Return days-since-first-case for every row of ``merged_df`` as of ``end_date``."""
    # Mapping over the single needed column gives the same values and index as
    # DataFrame.apply(axis=1), without building a row Series per county.
    return merged_df["date"].map(lambda s_date: days_since(s_date, end_date))


def _save_frame(merged_df, end_date, norm, cmap, output_dir, dpi=400):
    """Draw the map as of ``end_date``, save it as a JPEG and return its path."""
    # 1920x1080 px output at the given dpi.
    px, py = 1920, 1080
    fig, ax = plt.subplots(1, 1, figsize=(px / dpi, py / dpi), dpi=dpi)
    try:
        # Fixed viewing window (centre and size in projected map units) plus a margin.
        x_0, y_0 = 200000, -600000
        x_size, y_size = 4600000, 2900000
        buffer = 200000
        ax.set_xlim(x_0 - (x_size / 2) - buffer, x_0 + (x_size / 2) + buffer)
        ax.set_ylim(y_0 - (y_size / 2) - buffer, y_0 + (y_size / 2) + buffer)

        # Remove axis, add title
        ax.axis("off")
        ax.set_title("Days since first confirmed COVID-19 case", fontdict={"fontsize": "8", "fontweight": "4"}, y=0.9)

        # Colorbar legend. The ScalarMappable is not attached to any Axes, so
        # pass ax explicitly: recent Matplotlib versions deprecate (and later
        # reject) guessing it. set_array([]) is the public form of the old
        # ``sm._A = []`` hack.
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        fig.colorbar(sm, ax=ax, shrink=0.5, pad=0.0)

        # Add date annotation
        ax.annotate(end_date.isoformat(), xy=(0.2, .225), xycoords="figure fraction",
                    horizontalalignment="left", verticalalignment="top", fontsize=8)

        # Update column for days since first recorded infection
        merged_df["days_since_p0"] = _days_since_column(merged_df, end_date)

        # Plot data
        merged_df.plot(column="days_since_p0", cmap=cmap, linewidth=0.05, ax=ax, edgecolor="0.5", norm=norm)

        # Save plot
        print("Saving plot ", str(end_date))
        fig_path = os.path.join(output_dir, end_date.strftime("%Y-%m-%d") + "_covid.jpg")
        fig.savefig(fig_path, dpi=dpi)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return fig_path


def _write_gif(plot_files, gif_path, frame_duration=0.3, last_frame_duration=5.0):
    """Combine the images in ``plot_files`` (in order) into an animated GIF.

    Every frame is shown for ``frame_duration`` except the last, which is held
    for ``last_frame_duration`` so the final state lingers.
    """
    # NOTE(review): imageio's legacy v2 GIF writer reads ``duration`` as
    # seconds; newer pillow-backed releases treat it as milliseconds. Confirm
    # against the installed imageio version.
    durations = [frame_duration] * (len(plot_files) - 1) + [last_frame_duration]
    with imageio.get_writer(gif_path, mode='I', duration=durations) as writer:
        for plot_file in plot_files:
            writer.append_data(imageio.imread(plot_file))
        
        
def main():
    """Build the county map of days since first COVID-19 case and animate it."""
    # Input locations, relative to the working directory.
    shp_path, csv_path = "shapefiles/c_02jn20.shp", "data/us-counties.csv"

    counties = load_shapefile(shp_path)
    first_cases = load_cases_data(csv_path)

    # Attach each county's first-case date to its geometry, then render.
    make_plot(make_merged_df(counties, first_cases))


if __name__ == "__main__":
    main()

Saving plot  2020-01-21
Saving plot  2020-01-22
Saving plot  2020-01-23
Saving plot  2020-01-24
Saving plot  2020-01-25
Saving plot  2020-01-26
Saving plot  2020-01-27
Saving plot  2020-01-28
Saving plot  2020-01-29
Saving plot  2020-01-30
Saving plot  2020-01-31
Saving plot  2020-02-01
Saving plot  2020-02-02
Saving plot  2020-02-03
Saving plot  2020-02-04
Saving plot  2020-02-05
Saving plot  2020-02-06
Saving plot  2020-02-07
Saving plot  2020-02-08
Saving plot  2020-02-09
Saving plot  2020-02-10
Saving plot  2020-02-11
Saving plot  2020-02-12
Saving plot  2020-02-13
Saving plot  2020-02-14
Saving plot  2020-02-15
Saving plot  2020-02-16
Saving plot  2020-02-17
Saving plot  2020-02-18
Saving plot  2020-02-19
Saving plot  2020-02-20
Saving plot  2020-02-21
Saving plot  2020-02-22
Saving plot  2020-02-23
Saving plot  2020-02-24
Saving plot  2020-02-25
Saving plot  2020-02-26
Saving plot  2020-02-27
Saving plot  2020-02-28
Saving plot  2020-02-29
Saving plot  2020-03-01
Saving plot  202