In [183]:
from scipy.special import erf
from scipy.optimize import curve_fit
import numpy as np
import matplotlib.pyplot as plt
import time

import pandas as pd
from util import util

In [184]:
# sample_submission.csv fixes the row order (and the column names) of our
# output file.
sample_submission = pd.read_csv("../sample_submission.csv")

US_DATA_DIR = "../data/us"

# fips_key.csv contains standard information about each county.
key = pd.read_csv(f"{US_DATA_DIR}/processing_data/fips_key.csv", encoding='latin-1')

# Daily deaths: death count per day for each county.
# Cumulative deaths: running total of deaths for each county by day.
daily_deaths = pd.read_csv(f"{US_DATA_DIR}/covid/nyt_us_counties_daily.csv")
cumulative_deaths = pd.read_csv(f"{US_DATA_DIR}/covid/deaths.csv")
county_land_areas = pd.read_csv(f"{US_DATA_DIR}/demographics/county_land_areas.csv", encoding='latin1')
county_population = pd.read_csv(f"{US_DATA_DIR}/demographics/county_populations.csv", encoding='latin1')
mobility_data = pd.read_csv(f"{US_DATA_DIR}/mobility/DL-us-m50.csv", encoding='latin1')

# List of all county FIPS codes.
all_fips = key["FIPS"].tolist()

# Bring the project helper in under an alias. Writing `util = util(...)` would
# rebind the imported name to an instance, so re-running this cell would try
# to call that instance (presumably not callable) and fail. The instance
# itself keeps the name `util` because make_predictions reads it.
from util import util as UtilHelper

util = UtilHelper(daily_deaths, cumulative_deaths, county_land_areas,
                  county_population, mobility_data, key)

In [185]:
def linear_curve(times, slope, intercept):
    """Evaluate the line ``slope * t + intercept`` at every ``t`` in ``times``.

    Used as the model function for ``scipy.optimize.curve_fit``.

    Returns a float numpy array with one entry per element of ``times``.
    """
    # Vectorised: the previous ``[...] + intercept`` added a scalar to a Python
    # list, which only worked when ``intercept`` was a numpy scalar and raised
    # TypeError for a plain Python number.
    return np.asarray(times, dtype=float) * slope + intercept

def constant_curve(times, c):
    """Model function that predicts the constant ``c`` at every time in ``times``.

    Intended for ``scipy.optimize.curve_fit`` as a one-parameter baseline.
    Returns a float numpy array of length ``len(times)``.
    """
    # The previous body returned ``x * c`` (a line through the origin), which
    # contradicts the name; a constant model ignores the time values.
    return np.full(len(times), c, dtype=float)

In [186]:
def make_predictions(fips, startDate, endDate, n_past_steps, n_steps, min_total_deaths=20):
    """Forecast daily deaths for one county by extrapolating a linear fit.

    A straight line (``linear_curve``) is least-squares fitted with
    ``curve_fit`` to the most recent ``n_past_steps`` daily death counts up to
    ``endDate``, then evaluated at the next ``n_steps`` day indices.

    Parameters
    ----------
    fips : county FIPS code, passed to ``util.get_deaths_list``.
    startDate : unused; kept so existing call sites keep working.
        NOTE(review): the fit window is defined only by ``n_past_steps``.
    endDate : last date of the fit window, passed to ``util.get_deaths_list``.
    n_past_steps : maximum number of most recent days to fit on.
    n_steps : number of future days to predict.
    min_total_deaths : if the fit window holds fewer deaths than this in
        total, skip fitting and predict zeros (default 20).

    Returns
    -------
    ``n_steps`` predicted daily death counts: a list of zeros when the fit is
    skipped, otherwise the values produced by ``linear_curve``. Predictions
    may be negative; the caller clips them.
    """
    # Fit on the daily deaths list (the original comment wrongly said
    # cumulative deaths).
    daily_deaths_list = util.get_deaths_list(fips, endDate=endDate)[-n_past_steps:]
    n_points = len(daily_deaths_list)

    # Too few deaths to fit meaningfully, or too few points to determine a
    # two-parameter line (curve_fit would raise): predict zero.
    if n_points < 2 or np.sum(daily_deaths_list) < min_total_deaths:
        return [0] * n_steps

    # Index by the number of points actually returned. A county with less
    # history than n_past_steps would otherwise give x and y different
    # lengths, and the forecast would start at the wrong offset.
    x = np.arange(n_points)
    x_future = np.arange(n_points, n_points + n_steps)

    popt, _ = curve_fit(linear_curve, x, daily_deaths_list, maxfev=10000)
    return linear_curve(x_future, popt[0], popt[1])

def generate_quantiles(value):
    """Return nine values spread symmetrically around ``value``.

    Entry k (for k = -4..4) is ``value + value * 0.1 * k``, i.e. from 60% to
    140% of ``value`` in 10% steps, with ``value`` itself in the middle.
    """
    return [value + value * 0.1 * offset for offset in range(-4, 5)]

def get_id_list():
    """Return the submission row ids in sample_submission.csv order."""
    id_column = sample_submission["id"]
    return id_column.values

def extract_fips_from_id(row_id):
    """Return the text after the last '-' in ``row_id`` (the FIPS part).

    Ids have the form ``<date>-<fips>``. If ``row_id`` contains no '-', the
    whole string is returned.
    """
    _, _, fips_part = row_id.rpartition('-')
    return fips_part

def extract_date_from_id(row_id):
    """Return everything before the last '-' in ``row_id`` (the date part).

    Returns '' when ``row_id`` contains no '-'.
    """
    date_part, _, _ = row_id.rpartition('-')
    return date_part

In [190]:
# Fit window passed to make_predictions and the forecast horizon.
FIT_START_DATE = "2020-03-30"
FIT_END_DATE = "2020-05-09"
N_PAST_STEPS = 14

# Generate the forecast dates rather than typing them out, so none can be
# mistyped (the hand-written list had "2020-05-1217" in place of
# "2020-05-21", which left that day's predictions at zero).
dates_of_interest = (
    pd.date_range(start="2020-05-10", end="2020-05-27", freq="D")
    .strftime("%Y-%m-%d")
    .tolist()
)

# data[fips][date] -> nine quantiles for that county and day.
data = {}
for fips in all_fips:
    # One prediction per date of interest, so tie n_steps to the list length.
    predictions = make_predictions(
        fips, FIT_START_DATE, FIT_END_DATE, N_PAST_STEPS, len(dates_of_interest)
    )
    # Clip every quantile at zero (negative death counts are meaningless).
    data[fips] = {
        date: [max(q, 0) for q in generate_quantiles(predictions[i])]
        for i, date in enumerate(dates_of_interest)
    }

In [191]:
# Assemble one output row per id, in sample_submission order: the id followed
# by nine quantiles formatted to two decimals. Ids whose county or date has no
# prediction get all-zero quantiles; NaN or negative quantiles are written as
# zero.
ZERO_STR = "%.2f" % 0.00
lists = []
for row_id in get_id_list():
    row_date = extract_date_from_id(row_id)
    row_fips = int(extract_fips_from_id(row_id))

    county_preds = data.get(row_fips)
    if county_preds is None or row_date not in county_preds:
        lists.append([row_id] + [ZERO_STR] * 9)
        continue

    formatted = [
        ZERO_STR if (str(q) == "nan" or q < 0) else "%.2f" % q
        for q in county_preds[row_date]
    ]
    lists.append([row_id] + formatted)

df = pd.DataFrame(lists, columns=sample_submission.columns)
df.to_csv("linear_fit_submission.csv", index=False, sep=',')

In [199]:
# Spot-check the clipped quantiles for one county on one forecast date.
# NOTE(review): FIPS 6037 is presumably Los Angeles County, CA; verify in fips_key.csv.
spot_check = data[6037]["2020-05-25"]
print(spot_check)

[37.80659340663661, 44.10769230774271, 50.40879120884881, 56.70989010995491, 63.010989011061014, 69.31208791216712, 75.61318681327322, 81.91428571437932, 88.21538461548542]
