In [6]:
import pandas as pd
from prophet import Prophet
from pymongo import MongoClient
from prophet.plot import plot_plotly, plot_components_plotly
import os
import time
import pickle
import numpy as np
from datetime import datetime
import itertools
import dask
from dask.distributed import Client

In [7]:
df_parent = pd.read_csv('covid_parents_trained.csv')
df_parent.shape

(56, 4)

In [10]:
class TrainedParent:
    def __init__(self, gis_join, rmse, changepoint_prior_scale, seasonality_prior_scale):
        self.gis_join = gis_join
        self.rmse = rmse
        self.changepoint_prior_scale = changepoint_prior_scale
        self.seasonality_prior_scale = seasonality_prior_scale
        
    def __str__(self):
        return f'{self.gis_join}: (rmse={self.rmse}, changepoint_prior_scale={self.changepoint_prior_scale}, seasonality_prior_scale={self.seasonality_prior_scale})'
    
trained_parents_map = {}  
for i, row in df_parent.iterrows():
    gis_join = row['GISJOIN']
    rmse = row['rmse']
    changepoint_prior_scale = row['changepoint_prior_scale']
    seasonality_prior_scale = row['seasonality_prior_scale']
    trained_parents_map[gis_join] = (TrainedParent(gis_join, 
                                                   rmse, changepoint_prior_scale, seasonality_prior_scale))
    
print(len(trained_parents_map.keys()))

56


In [11]:
# Load Child Parent Map
parent_child_map = pickle.load(open('pickles/parent_child_map.pkl', 'rb'))

In [18]:
df_clusters = pd.read_csv('./clusters-covid.csv')
df_clusters.head()

Unnamed: 0.1,Unnamed: 0,GISJOIN,cluster_id,distance,is_parent,frac_distance,sample_percent
0,0,G0100010,28,2.045455,0.0,0.011364,0.052273
1,1,G0100030,8,47.464286,0.0,0.158088,0.081618
2,2,G0100050,51,5.167382,0.0,0.104167,0.070833
3,3,G0100070,51,16.832618,0.0,0.347192,0.119438
4,4,G0100090,16,97.492857,0.0,0.90566,0.231132


In [19]:
parents_pickle = set(parent_child_map.keys())
parents_csv = set(df_clusters[df_clusters['is_parent'] == 1]['GISJOIN'])

print(f'parents_pickle: {parents_pickle}')
print(f'parents_csv: {parents_csv}')

print(len(parents_pickle - parents_csv))

parents_pickle: {'G4201010', 'G0100790', 'G5300350', 'G3600290', 'G1300210', 'G0400130', 'G1800950', 'G3601190', 'G0600650', 'G1801030', 'G1200990', 'G3901510', 'G4804390', 'G0500330', 'G0500890', 'G4803750', 'G1200710', 'G1600810', 'G4500790', 'G4801570', 'G4800290', 'G3400290', 'G1901670', 'G3701190', 'G1600750', 'G4800270', 'G1200860', 'G4900490', 'G4701710', 'G2100670', 'G4600910', 'G4500450', 'G0800690', 'G1200570', 'G4802010', 'G3900490', 'G0600730', 'G2101110', 'G1301170', 'G4500310', 'G2700270', 'G3400070', 'G5500090', 'G0800410', 'G0600010', 'G4200710', 'G5300050', 'G3000310', 'G4200910', 'G3600590', 'G4900350', 'G0600370', 'G0600590', 'G4001090', 'G1700270', 'G3900350'}
parents_csv: {'G4201010', 'G0100790', 'G5300350', 'G3600290', 'G1300210', 'G0400130', 'G1800950', 'G1801030', 'G3601190', 'G1200990', 'G3901510', 'G4804390', 'G0500330', 'G0500890', 'G4803750', 'G1700270', 'G1200710', 'G1600810', 'G4500790', 'G3400290', 'G4800290', 'G4801570', 'G1901670', 'G3701190', 'G1600750

In [24]:
# child GISJOIN to sample_percent map
child_sample_perc_map = {}
for i, row in df_clusters.iterrows():
    is_parent = row['is_parent']
    sample_percent = row['sample_percent']
    if not is_parent and sample_percent <= 0.15:
        gis_join = row['GISJOIN']
        
        child_sample_perc_map[gis_join] = sample_percent
        
children_list = list(child_sample_perc_map.keys())
no_of_children = len(children_list)
no_of_parents = len(trained_parents_map.keys())

# assert no_of_children == (df_clusters.shape[0] - no_of_parents)
print(f'no_of_children: {no_of_children}')
print(f'no_of_parents: {no_of_parents}')

no_of_children: 1871
no_of_parents: 56



In [32]:
db = MongoClient("lattice-100", 27018)
collection = 'covid_county_formatted'

def get_df_by_gis_join(gis_join, sample_percent=1.0):
    print(gis_join, end=' ')
    cursor = db.sustaindb[collection].aggregate([{"$match": {"GISJOIN": gis_join}}])
    df = pd.DataFrame(list(cursor))[['date', 'cases']]
    df.columns = ['ds','y']
    return df.sample(frac=sample_percent)

In [33]:
def predict_transfer(df_train, parent_trained):
    time1 = time.monotonic()
    # initilaize model with hyperparameters from parent model
    m = Prophet(
        seasonality_prior_scale = parent_trained.seasonality_prior_scale,
        changepoint_prior_scale = parent_trained.changepoint_prior_scale,
    )
    m.fit(df_train, algorithm='LBFGS')
    df_train_future = m.make_future_dataframe(periods=300, freq='H')
    df_train_forecast = m.predict(df_train_future)
    time2 = time.monotonic()

    return m, df_train_future, df_train_forecast, (time2 - time1)


def predict_transfer_task(df_train, gis_join, parent_trained):
    m, df_train_future, df_train_forecast, time_taken = predict_transfer(df_train, parent_trained)
    return gis_join, time_taken


children_dfs_map = {}

for gis_join, sample_percent in child_sample_perc_map.items():
    children_dfs_map[gis_join] = get_df_by_gis_join(gis_join, sample_percent)

G0100010 G0100030 G0100050 G0100070 G0100110 G0100150 G0100210 G0100250 G0100290 G0100330 G0100350 G0100370 G0100390 G0100410 G0100430 G0100450 G0100490 G0100510 G0100530 G0100550 G0100570 G0100630 G0100690 G0100730 G0100770 G0100810 G0100870 G0100890 G0100930 G0100990 G0101010 G0101030 G0101070 G0101090 G0101110 G0101130 G0101170 G0101190 G0101210 G0101290 G0101330 G0200130 G0200160 G0200200 G0200680 G0200700 G0200900 G0201000 G0201100 G0201300 G0201500 G0201700 G0201880 G0201950 G0201980 G0202200 G0202300 G0202610 G0202750 G0202900 G0400010 G0400050 G0400110 G0400120 G0400210 G0400230 G0500010 G0500030 G0500110 G0500130 G0500150 G0500190 G0500210 G0500230 G0500250 G0500270 G0500350 G0500390 G0500410 G0500470 G0500510 G0500550 G0500610 G0500630 G0500670 G0500730 G0500750 G0500870 G0500930 G0500950 G0500970 G0500990 G0501030 G0501050 G0501070 G0501110 G0501130 G0501150 G0501170 G0501210 G0501250 G0501270 G0501290 G0501310 G0501330 G0501370 G0501390 G0501410 G0501450 G0501490 G0600030 G

G2901630 G2901670 G2901690 G2901710 G2901730 G2901750 G2901790 G2901810 G2901830 G2901850 G2901860 G2901970 G2901990 G2902030 G2902050 G2902070 G2902110 G2902150 G2902170 G2902190 G2902210 G2902250 G2902270 G2902290 G2905100 G3000010 G3000050 G3000070 G3000090 G3000110 G3000130 G3000150 G3000190 G3000250 G3000290 G3000330 G3000370 G3000390 G3000410 G3000430 G3000450 G3000510 G3000550 G3000570 G3000590 G3000610 G3000650 G3000690 G3000710 G3000730 G3000750 G3000770 G3000790 G3000810 G3000830 G3000850 G3000870 G3000890 G3000910 G3000950 G3000970 G3001010 G3001030 G3001050 G3001070 G3001090 G3100030 G3100050 G3100070 G3100090 G3100110 G3100150 G3100170 G3100210 G3100230 G3100270 G3100290 G3100310 G3100330 G3100370 G3100450 G3100490 G3100510 G3100530 G3100570 G3100590 G3100610 G3100630 G3100650 G3100670 G3100690 G3100710 G3100730 G3100750 G3100770 G3100790 G3100830 G3100850 G3100870 G3100890 G3100910 G3100970 G3101030 G3101050 G3101070 G3101110 G3101130 G3101150 G3101170 G3101190 G3101210 G

G5500870 G5500890 G5500910 G5500950 G5501010 G5501050 G5501070 G5501110 G5501130 G5501190 G5501210 G5501250 G5501270 G5501330 G5501390 G5501410 G5600010 G5600030 G5600050 G5600070 G5600110 G5600130 G5600150 G5600170 G5600190 G5600210 G5600230 G5600270 G5600290 G5600310 G5600330 G5600370 G5600390 G5600410 G5600430 G5600450 

In [34]:
pickle.dump(child_sample_perc_map, open('pickles/child_sample_perc_map.pkl', 'wb'))

print(len(children_list))
print(len(children_dfs_list))

1871
1871


In [37]:
client = Client('lattice-150:8786')

time1 = time()

counter = 1
lazy_results = []
for parent, children in parent_child_map.items():
    try:
        parent_trained_  = trained_parents_map[parent]
        print(parent_trained_)
        for child in children:
            # check if the child falls under < 0.15 sampling percentage
            if child in children_dfs_map:
                child_df = children_dfs_map[child]
                lazy_result = dask.delayed(predict_transfer_task)(child_df, gis_join, parent_trained_)
                lazy_results.append(lazy_result)
    except Exception as e:
        print(f'Error on {gis_join}')
        print(f"{type(e).__name__} at line {e.__traceback__.tb_lineno} of {__file__}: {e}")
    if counter % 100 == 0:
        print(counter, end=', ')
    counter += 1
#     break

futures = dask.persist(*lazy_results)  # trigger computation in the background
results = dask.compute(*futures)
results[:5]

time2 = time()
print(f'Time taken (dataset=COVID-19, childModels 0-15%: {time2 - time1}')

G2700270: (rmse=1.1558735324355147, changepoint_prior_scale=0.5, seasonality_prior_scale=10.0)
G0400130: (rmse=316.3039434010886, changepoint_prior_scale=0.5, seasonality_prior_scale=0.1)
G4900490: (rmse=445.9742756900704, changepoint_prior_scale=0.5, seasonality_prior_scale=0.1)
G0600370: (rmse=711.9078978906454, changepoint_prior_scale=0.1, seasonality_prior_scale=10.0)
G4800290: (rmse=331.0911518720875, changepoint_prior_scale=0.5, seasonality_prior_scale=0.01)
G1600750: (rmse=16.354557478885457, changepoint_prior_scale=0.5, seasonality_prior_scale=1.0)
G3400070: (rmse=43.5943167831879, changepoint_prior_scale=0.5, seasonality_prior_scale=0.1)
G0600650: (rmse=217.79810724772184, changepoint_prior_scale=0.1, seasonality_prior_scale=1.0)
G4800270: (rmse=51.60909367232097, changepoint_prior_scale=0.1, seasonality_prior_scale=10.0)
G2101110: (rmse=41.75339026097891, changepoint_prior_scale=0.1, seasonality_prior_scale=10.0)
G4802010: (rmse=575.6233054560781, changepoint_prior_scale=0.5,

(('G5600450', 69.01250222604722),
 ('G5600450', 75.13486853707582),
 ('G5600450', 73.20243107806891),
 ('G5600450', 66.80214228620753),
 ('G5600450', 65.58361178310588))

In [39]:
gis_joins = []
times = []

for r, t in results:
    gis_joins.append(r)
    times.append(t)
    
times_0_15_df = pd.DataFrame(zip(gis_joins, times), columns=['GISJOIN', 'time'])
times_0_15_df.to_csv('covid-19-child_training_tl_times_0_15.csv', index=False)
times_0_15_df

Unnamed: 0,GISJOIN,time
0,G5600450,69.012502
1,G5600450,75.134869
2,G5600450,73.202431
3,G5600450,66.802142
4,G5600450,65.583612
...,...,...
1866,G5600450,68.049765
1867,G5600450,70.349388
1868,G5600450,71.403052
1869,G5600450,73.929558
