In [None]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import requests
from dotenv import load_dotenv

# Read a local .env file (if present) into the process environment.
# get_population() below looks up CENSUS_API_KEY in os.environ, so the key
# is presumably expected to be supplied this way.
load_dotenv()

In [None]:
# Directory holding the CSV files read below. A `live/` subdirectory is
# expected to contain files with the same names, which get appended to the
# top-level ones.
# NOTE(review): presumably a checkout of the NYT covid-19-data repository - confirm.
DATA_DIR = 'nyt-data'


def get_top_3_by_latest_cases(df, group):
    """Return the `group` values of the three top-ranked rows.

    Rows are ranked newest `date` first and, within a date, largest `cases`
    first; the `group` column of the first three rows is returned.

    Args:
        df: DataFrame with `date`, `cases` and `group` columns.
        group: Name of the column that identifies each region.

    Returns:
        A list with up to three `group` values, best-ranked first.
    """
    newest_largest_first = df.sort_values(by=['date', 'cases'], ascending=False)
    leaders = newest_largest_first[group].iloc[:3]
    return leaders.tolist()


def calculate_daily_stats(df, group=None):
    """Add day-over-day `daily_cases` / `daily_deaths` columns to `df`.

    Each new column is the row-to-row difference of the matching cumulative
    column. When `group` is given, differences are taken within each group
    rather than across the whole frame. The first row (of the frame, or of
    each group) has no predecessor and is filled with 0.

    Args:
        df: DataFrame with cumulative `cases` and `deaths` columns, already
            in chronological row order. It is modified in place.
        group: Optional column name to group by before differencing.

    Returns:
        The same `df` object, with the two columns added.
    """
    source = df.groupby(group) if group else df
    # Compute both results before touching `df` so neither diff sees the
    # other's freshly added column.
    case_deltas = source['cases'].diff().fillna(0)
    death_deltas = source['deaths'].diff().fillna(0)
    df['daily_cases'] = case_deltas
    df['daily_deaths'] = death_deltas
    return df


def _read_with_live(filename):
    """Read DATA_DIR/<filename> and DATA_DIR/live/<filename> and stack them.

    Only the columns present in both files are kept (inner join on columns).
    """
    historical = pd.read_csv(os.path.join(DATA_DIR, filename))
    live = pd.read_csv(os.path.join(DATA_DIR, 'live', filename))
    return pd.concat([historical, live], join='inner')


# National totals, indexed by date.
country_df = calculate_daily_stats(_read_with_live('us.csv').set_index('date'))

# The three states ranked highest on the most recent date.
states_df = _read_with_live('us-states.csv')
top_3_states = get_top_3_by_latest_cases(states_df, 'state')
top_3_states_df = calculate_daily_stats(
    states_df[states_df['state'].isin(top_3_states)].set_index('date'),
    'state',
)

# The three California counties ranked highest on the most recent date.
counties_df = _read_with_live('us-counties.csv')
ca_counties_df = counties_df[counties_df['state'] == 'California']
top_3_ca_counties = get_top_3_by_latest_cases(ca_counties_df, 'county')
top_3_ca_counties_df = calculate_daily_stats(
    ca_counties_df[ca_counties_df['county'].isin(top_3_ca_counties)].set_index('date'),
    'county',
)

# Show the last row(s) of each frame as a quick check of the loaded data.
print(country_df.tail(1))
print(top_3_states_df.tail(3))
print(top_3_ca_counties_df.tail(3))

In [None]:
# Figure size (width, height) handed to plt.subplots via `figsize`.
SIZE = (20, 24)

# 7 rows x 2 columns of axes: one row per region drawn below (U.S., three
# states, three California counties); each row holds a cases and a deaths chart.
fig, axs = plt.subplots(7, 2, figsize=SIZE)


def draw_daily_graphs(ax_row, df, region_label):
    """Draw daily cases and daily deaths charts for one region.

    `ax_row[0]` receives the cases chart and `ax_row[1]` the deaths chart.
    Each chart is a bar plot of the daily values with a red line showing the
    rolling mean over up to 7 rows (fewer at the start of the series).
    X-axis ticks are hidden.

    Args:
        ax_row: Pair of matplotlib axes (cases, deaths).
        df: DataFrame with `daily_cases` and `daily_deaths` columns; its index
            is used for the x axis.
        region_label: Text appended to each chart title.
    """
    panels = (
        (ax_row[0], 'daily_cases', f'Daily Cases {region_label}'),
        (ax_row[1], 'daily_deaths', f'Daily Deaths {region_label}'),
    )
    for ax, column, title in panels:
        daily = df[column]
        smoothed = daily.rolling(window=7, min_periods=1).mean()
        ax.set(xticks=[], title=title)
        ax.bar(df.index, daily)
        ax.plot(df.index, smoothed, color='red')


# Row 0: national totals.
draw_daily_graphs(axs[0], country_df, 'U.S.')

# Rows 1-3: one row per top state. Pairing axes rows with names directly
# avoids index bookkeeping and leaves trailing rows blank (instead of
# raising IndexError) if fewer than three states are available.
for ax_row, state in zip(axs[1:4], top_3_states):
    draw_daily_graphs(
        ax_row,
        top_3_states_df[top_3_states_df['state'] == state],
        state,
    )

# Rows 4-6: one row per top California county, handled the same way.
for ax_row, county in zip(axs[4:7], top_3_ca_counties):
    draw_daily_graphs(
        ax_row,
        top_3_ca_counties_df[top_3_ca_counties_df['county'] == county],
        county,
    )

plt.show()

In [None]:
def get_population(fips, timeout=30):
    """Fetch a region's population from the Census 2019 population endpoint.

    Args:
        fips: Region identifier.
            * Falsy (None or 0): the whole U.S.
            * 1000 or greater: a combined state+county FIPS code, i.e.
              ``state * 1000 + county`` (e.g. 6059 or 36061).
            * Otherwise: a state FIPS code.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The `POP` value from the first data row of the response, as an int.

    Raises:
        KeyError: If CENSUS_API_KEY is not set in the environment.
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If the request exceeds `timeout`.
    """
    if not fips:
        query = {'for': 'us:1'}
    else:
        fips = int(fips)
        if fips >= 1000:
            # The county part is always the last three digits. Splitting
            # arithmetically (rather than taking the first character as the
            # state) also handles two-digit state codes such as 36061.
            state_fips, county_fips = divmod(fips, 1000)
            query = {'for': f'county:{county_fips:03}', 'in': f'state:{state_fips:02}'}
        else:
            query = {'for': f'state:{fips:02}'}

    r = requests.get(
        'https://api.census.gov/data/2019/pep/population',
        params={
            'key': os.environ['CENSUS_API_KEY'],
            'get': 'POP',
            **query,
        },
        # Without a timeout a stalled connection would hang the notebook.
        timeout=timeout,
    )
    r.raise_for_status()
    # Row 0 of the JSON payload is skipped; row 1, column 0 is the requested value.
    return int(r.json()[1][0])


def calculate_per_million_stats(df):
    """Scale a region's case and death figures to per-million-population values.

    The population is looked up with `get_population`, using the `fips` value
    of the first row when the frame has a `fips` column and None otherwise.

    Args:
        df: Non-empty DataFrame for a single region with `cases`, `deaths`,
            `daily_cases` and `daily_deaths` columns and, optionally, `fips`.

    Returns:
        A new DataFrame sharing `df`'s index with the columns `cases`,
        `deaths`, `daily_cases_7_day_avg` and `daily_deaths_7_day_avg`; the
        two averages are rolling means over up to 7 rows.
    """
    # .iloc[0] picks the first row by position. A plain [0] on a Series with
    # a non-integer index relies on a positional fallback that pandas has
    # deprecated.
    fips = int(df['fips'].iloc[0]) if 'fips' in df.columns else None
    population = get_population(fips)

    columns = {
        'cases': df['cases'],
        'deaths': df['deaths'],
        'daily_cases_7_day_avg': df['daily_cases'].rolling(7, 1).mean(),
        'daily_deaths_7_day_avg': df['daily_deaths'].rolling(7, 1).mean(),
    }
    per_million_df = pd.DataFrame()
    for name, values in columns.items():
        per_million_df[name] = 1_000_000 * values / population
    return per_million_df


def draw_per_region_graphs(ax, country_s, state_s, county_s, title):
    """Plot one metric for three regions as lines on a single axes.

    The series are drawn in the order country (red, labelled 'U.S.'), state
    (blue, labelled 'California') and county (green, labelled 'Orange'),
    each against its own index. X-axis ticks are hidden and a legend is added.

    Args:
        ax: Matplotlib axes to draw on.
        country_s: Series for the country.
        state_s: Series for the state.
        county_s: Series for the county.
        title: Title of the chart.
    """
    ax.set(xticks=[], title=title)
    regions = (
        (country_s, 'red', 'U.S.'),
        (state_s, 'blue', 'California'),
        (county_s, 'green', 'Orange'),
    )
    for series, color, label in regions:
        ax.plot(series.index, series, color=color, label=label)
    ax.legend()


country_per_million_df = calculate_per_million_stats(country_df)
# California and Orange County are taken from the full state / California
# county frames rather than from the top-3 frames, so this cell keeps working
# when either region is not currently among the top three.
state_per_million_df = calculate_per_million_stats(
    calculate_daily_stats(
        states_df[states_df['state'] == 'California'].set_index('date'),
    )
)
county_per_million_df = calculate_per_million_stats(
    calculate_daily_stats(
        ca_counties_df[ca_counties_df['county'] == 'Orange'].set_index('date'),
    )
)

SIZE = (20, 8)

fig, axs = plt.subplots(2, 2, figsize=SIZE)

# One chart per metric, filled left-to-right then top-to-bottom.
per_million_panels = [
    ('cases', 'Cases per Million'),
    ('deaths', 'Deaths per Million'),
    ('daily_cases_7_day_avg', 'Daily Cases (7 Day Avg) per Million'),
    ('daily_deaths_7_day_avg', 'Daily Deaths (7 Day Avg) per Million'),
]
for ax, (column, title) in zip(axs.flat, per_million_panels):
    draw_per_region_graphs(
        ax,
        country_per_million_df[column],
        state_per_million_df[column],
        county_per_million_df[column],
        title,
    )

plt.show()