In [1]:
import pandas as pd
from prophet import Prophet
from pymongo import MongoClient
from prophet.plot import plot_plotly, plot_components_plotly
import os
import time
import pickle
from datetime import datetime
import itertools
import dask
from dask.distributed import Client

Importing plotly failed. Interactive plots will not work.


In [2]:
# Load the trained-parent results table. The next cell reads its GISJOIN, rmse,
# changepoint_prior_scale and seasonality_prior_scale columns.
df_parent = pd.read_csv('covid_parents_trained.csv')
# Display (rows, columns) as the cell output.
df_parent.shape

(56, 4)

In [3]:
class TrainedParent:
    """Record of one trained parent model: its GISJOIN, RMSE and Prophet prior scales.

    Every constructor argument is stored unchanged on the instance under the
    same name.
    """

    def __init__(self, gis_join, rmse, changepoint_prior_scale, seasonality_prior_scale):
        # Identifier and score of the parent model.
        self.gis_join, self.rmse = gis_join, rmse
        # Hyperparameters that are later handed to Prophet for the children.
        self.changepoint_prior_scale, self.seasonality_prior_scale = (
            changepoint_prior_scale,
            seasonality_prior_scale,
        )

    def __str__(self):
        """Return a one-line summary: ``<gis_join>: (rmse=..., changepoint_prior_scale=..., seasonality_prior_scale=...)``."""
        summary = f'{self.gis_join}: (rmse={self.rmse}, changepoint_prior_scale={self.changepoint_prior_scale}, seasonality_prior_scale={self.seasonality_prior_scale})'
        return summary
    
# Index every trained parent by its GISJOIN so later cells can look up the
# hyperparameters to reuse for a given parent.
trained_parents_map = {}
for _, parent_row in df_parent.iterrows():
    parent_model = TrainedParent(
        parent_row['GISJOIN'],
        parent_row['rmse'],
        parent_row['changepoint_prior_scale'],
        parent_row['seasonality_prior_scale'],
    )
    # A repeated GISJOIN overwrites the earlier entry (last row wins).
    trained_parents_map[parent_model.gis_join] = parent_model

# Number of distinct parents loaded.
print(len(trained_parents_map))

56


In [4]:
# Load Child Parent Map
# The dispatch cell indexes this mapping by child GISJOIN and uses the value as
# a key into trained_parents_map.
# NOTE: pickle.load can execute arbitrary code; only load this file from a
# trusted source.
# Use a context manager so the file handle is closed deterministically.
with open('ucc-21/covid_child_parent_map.pkl', 'rb') as child_parent_file:
    child_parent_map = pickle.load(child_parent_file)

In [5]:
# import numpy as np

# trained_parents_list = list(trained_parents_map.keys())
# loaded_parents_list = list(child_parent_map.values())
# difference = set(trained_parents_list) - set(loaded_parents_list)
# print(f'Difference: {len(difference)}')

# parents_npa = np.asarray(parents_list)

# unique_parents = np.unique(parents_npa)
# print(len(unique_parents))
# unique_parents

### Original Parents

G0400130, G0400190, G0400270, G0500590, G0500690, G0501030, G0600370, G0600590, G0600650, G0600710, G0600730, G0800050, G0800150, G0900090, G1200090, G1200110, G1200170, G1200830, G1201050, G1301350, G1700310, G1700430, G1701110, G1800970, G1900610, G1901010, G2405100, G2601210, G2700810, G2900950, G3000130, G3200030, G3300150, G3400030, G3500010, G3600510, G3600590, G3600910, G3900170, G3900350, G3900410, G4001090, G4200030, G4201010, G4600990, G4800270, G4800290, G4800850, G4801090, G4801130, G4801570, G4802010, G4802150, G4900350, G5100090, G5400550

In [6]:
# Load the cluster assignment table. The next cell reads its GISJOIN,
# is_parent and sample_percent columns.
df_clusters = pd.read_csv('~/ucc-21/clusters-covid.csv')
# Preview the first rows as the cell output.
df_clusters.head()

Unnamed: 0.1,Unnamed: 0,GISJOIN,cluster_id,distance,is_parent,frac_distance,sample_percent
0,0,G0100010,39,7.582524,0.0,0.046117,0.059223
1,1,G0100030,37,21.277778,0.0,0.109459,0.071892
2,2,G0100050,47,22.647059,0.0,0.288432,0.107686
3,3,G0100070,22,53.160338,0.0,0.611449,0.17229
4,4,G0100090,29,55.71875,0.0,0.522091,0.154418


In [7]:
# child GISJOIN to sample_percent map
# A row is kept only when it is not flagged as a parent and its sample_percent
# is at most 0.15. A repeated GISJOIN keeps the value from the last row.
child_map = {
    cluster_row['GISJOIN']: cluster_row['sample_percent']
    for _, cluster_row in df_clusters.iterrows()
    if not cluster_row['is_parent'] and cluster_row['sample_percent'] <= 0.15
}

no_of_children = len(child_map)
no_of_parents = len(trained_parents_map)

# assert no_of_children == (df_clusters.shape[0] - no_of_parents)
print(no_of_children)

1817


In [8]:
# MongoDB connection used by get_df_by_gis_join. Despite its name, `db` holds
# the MongoClient itself; the database is selected at the call site as
# `db.sustaindb`.
db = MongoClient("lattice-100", 27018)
# Name of the collection that get_df_by_gis_join queries.
collection = 'covid_county_formatted'

def get_df_by_gis_join(gis_join, sample_percent=1.0, random_state=None):
    """Fetch the `date`/`cases` records for one GISJOIN and return a random sample.

    Parameters
    ----------
    gis_join
        Value matched against the ``GISJOIN`` field of the documents in
        ``db.sustaindb[collection]``.
    sample_percent : float, default 1.0
        Fraction of rows to keep, forwarded to ``DataFrame.sample(frac=...)``.
    random_state : optional
        Seed forwarded to ``DataFrame.sample``. The default ``None`` keeps the
        previous unseeded behaviour; pass a value to make the sample
        reproducible.

    Returns
    -------
    pandas.DataFrame
        Two columns named ``ds`` (from ``date``) and ``y`` (from ``cases``),
        with rows in the order produced by ``DataFrame.sample``.

    Raises
    ------
    KeyError
        If no document matches ``gis_join``.
    """
    # Progress marker: prints the id followed by a space, without a newline.
    print(gis_join, end=' ')
    cursor = db.sustaindb[collection].aggregate([{"$match": {"GISJOIN": gis_join}}])
    records = list(cursor)
    if not records:
        # An empty frame has no 'date'/'cases' columns, so the selection below
        # would fail with an opaque pandas KeyError; fail with a clear one.
        raise KeyError(f'No documents found for GISJOIN {gis_join!r} in collection {collection!r}')
    df = pd.DataFrame(records)[['date', 'cases']]
    # Rename to the column names the Prophet model in this notebook is fit on.
    df.columns = ['ds', 'y']
    return df.sample(frac=sample_percent, random_state=random_state)

In [9]:
def predict_transfer(df_train, parent_trained, periods=300, freq='H'):
    """Fit a Prophet model on `df_train` using a parent's hyperparameters and forecast.

    Parameters
    ----------
    df_train
        Training frame passed directly to ``Prophet.fit``.
    parent_trained
        Object exposing ``seasonality_prior_scale`` and
        ``changepoint_prior_scale`` attributes (e.g. a ``TrainedParent``).
    periods : int, default 300
        Number of future periods passed to ``make_future_dataframe``.
    freq : str, default 'H'
        Frequency string passed to ``make_future_dataframe``.
        NOTE(review): 'H' is hourly -- confirm this matches the granularity of
        the ``ds`` column in the training data.

    Returns
    -------
    tuple
        ``(model, future_dataframe, forecast_dataframe, elapsed_seconds)``,
        where ``elapsed_seconds`` covers model construction, fitting and
        prediction, measured with ``time.monotonic()``.
    """
    time1 = time.monotonic()
    # Initialize the model with the hyperparameters from the parent model.
    m = Prophet(
        seasonality_prior_scale=parent_trained.seasonality_prior_scale,
        changepoint_prior_scale=parent_trained.changepoint_prior_scale,
    )
    m.fit(df_train, algorithm='LBFGS')
    df_train_future = m.make_future_dataframe(periods=periods, freq=freq)
    df_train_forecast = m.predict(df_train_future)
    time2 = time.monotonic()

    return m, df_train_future, df_train_forecast, (time2 - time1)


def predict_transfer_task(df_train, gis_join, parent_trained):
    """Run `predict_transfer` for one child and return ``(gis_join, elapsed_seconds)``.

    The fitted model and both dataframes returned by `predict_transfer` are
    discarded; only the timing is reported back to the caller.
    """
    fit_outputs = predict_transfer(df_train, parent_trained)
    _model, _future, _forecast, elapsed_seconds = fit_outputs
    return gis_join, elapsed_seconds


# child_list = []
# child_dfs_list = []

# for gis_join, sample_percent in child_map.items():
#     child_list.append(gis_join)
#     child_dfs_list.append(get_df_by_gis_join(gis_join, sample_percent))
# Load the cached child ids and their dataframes instead of rebuilding them.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted
# source.
# Context managers close each file handle deterministically.
with open('ucc-21/child_list_0_15.pkl', 'rb') as child_list_file:
    child_list = pickle.load(child_list_file)
with open('ucc-21/child_dfs_list_0_15.pkl', 'rb') as child_dfs_file:
    child_dfs_list = pickle.load(child_dfs_file)

In [10]:
# Sanity check: the two lists are zipped together in the dispatch cell, so
# their lengths should match (zip would silently truncate to the shorter one).
print(len(child_list))
print(len(child_dfs_list))

1817
1817


In [16]:
# Save a copy of the child id list. The context manager guarantees the file is
# flushed and closed once the dump completes.
with open('ucc-21/temp_child_list.pkl', 'wb') as temp_child_list_file:
    pickle.dump(child_list, temp_child_list_file)

In [11]:
# pickle.dump(child_list, open('ucc-21/child_list_0_15.pkl', 'wb'))
# pickle.dump(child_dfs_list, open('ucc-21/child_dfs_list_0_15.pkl', 'wb'))

In [13]:
# Connect to the Dask scheduler that runs the delayed tasks built below.
client = Client('lattice-150:8786')

lazy_results = []
# Children that could not be scheduled because a lookup failed.
skipped_children = []
for counter, (gis_join, df_) in enumerate(zip(child_list, child_dfs_list), start=1):
    try:
        parent = child_parent_map[gis_join]
        parent_trained_ = trained_parents_map[parent]
    except KeyError as missing_key:
        # Only a failed lookup is a reason to skip a child; any other exception
        # (including KeyboardInterrupt) now propagates instead of being hidden.
        # The missing key is either the child id (absent from child_parent_map)
        # or its parent id (absent from trained_parents_map).
        skipped_children.append(gis_join)
        print(f'Error on {gis_join}: no entry for {missing_key}')
    else:
        print(parent_trained_)
        lazy_result = dask.delayed(predict_transfer_task)(df_, gis_join, parent_trained_)
        lazy_results.append(lazy_result)
    # Progress marker every 100 children, whether scheduled or skipped.
    if counter % 100 == 0:
        print(counter, end=', ')

print(f'\n{len(lazy_results)} tasks scheduled, {len(skipped_children)} children skipped')

futures = dask.persist(*lazy_results)  # trigger computation in the background
results = dask.compute(*futures)  # block until every task has finished
results[:5]

Error on G0100010
Error on G0100030
Error on G0100050
Error on G0100110
Error on G0100150
Error on G0100290
Error on G0100310
Error on G0100350
Error on G0100370
Error on G0100430
Error on G0100470
Error on G0100490
Error on G0100510
Error on G0100550
Error on G0100570
Error on G0100610
Error on G0100630
Error on G0100690
Error on G0100710
Error on G0100750
Error on G0100770
Error on G0100790
Error on G0100830
Error on G0100890
Error on G0100910
Error on G0100930
G1701110: (rmse=56.54114220754974, changepoint_prior_scale=0.1, seasonality_prior_scale=0.01)
Error on G0101010
Error on G0101030
Error on G0101050
Error on G0101070
Error on G0101110
Error on G0101170
Error on G0101190
Error on G0101210
Error on G0101230
Error on G0101250
Error on G0101310
Error on G0200130
Error on G0200200
Error on G0200500
Error on G0200680
Error on G0200700
Error on G0200900
Error on G0201000
Error on G0201100
Error on G0201300
Error on G0201500
Error on G0201700
Error on G0201880
Error on G0201950
Error 

Error on G3801050
Error on G3900050
Error on G3900070
Error on G3900110
Error on G3900150
Error on G3900250
Error on G3900290
Error on G3900310
Error on G3900330
Error on G3900430
Error on G3900450
Error on G3900490
Error on G3900510
Error on G3900570
G1700430: (rmse=174.3024734636542, changepoint_prior_scale=0.01, seasonality_prior_scale=0.1)
Error on G3900630
Error on G3900710
Error on G3900730
Error on G3900770
Error on G3900790
Error on G3900810
Error on G3900870
Error on G3900910
Error on G3900930
Error on G3900950
Error on G3900970
Error on G3900990
Error on G3901010
Error on G3901050
Error on G3901070
1200, Error on G3901150
Error on G3901190
Error on G3901250
Error on G3901290
Error on G3901330
Error on G3901370
Error on G3901450
Error on G3901470
Error on G3901490
Error on G3901550
Error on G3901570
Error on G3901590
Error on G3901610
Error on G3901730
Error on G4000030
Error on G4000070
Error on G4000110
Error on G4000130
Error on G4000150
Error on G4000170
Error on G4000190


(('G0100970', 72.88429450499825),
 ('G0600190', 74.34747477294877),
 ('G0600470', 75.01832640799694),
 ('G0600750', 75.16081715305336),
 ('G0600850', 75.130009528948))

In [14]:
# Split the (GISJOIN, elapsed seconds) pairs returned by the tasks into two
# parallel lists, then assemble them into a two-column table.
gis_joins = [task_gis_join for task_gis_join, _ in results]
times = [task_seconds for _, task_seconds in results]

times_0_15_df = pd.DataFrame({'GISJOIN': gis_joins, 'time': times})
# times_0_15_df.to_csv('ucc-21/child_training_tl_times_0_15.csv', index=False)
times_0_15_df

Unnamed: 0,GISJOIN,time
0,G0100970,72.884295
1,G0600190,74.347475
2,G0600470,75.018326
3,G0600750,75.160817
4,G0600850,75.13001
5,G1200210,74.434802
6,G1701190,74.984225
7,G2100670,75.101669
8,G2101110,75.236491
9,G2601250,75.215337
