<h1>Flu Onset Prediction</h1>
Table of contents:
<ol>
    <li><a href="#imputation">Imputation</a></li>
    <li><a href="#time_to_onset">Time to Onset</a></li>
</ol>

In [None]:
%load_ext autoreload

In [None]:
%autoreload
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, RandomSampler

import pandas as pd
import numpy as np

import os
import glob

from tqdm import tqdm_notebook as tqdm
from tqdm import tqdm_pandas

import json
import pickle
from dataclasses import dataclass, asdict
from copy import deepcopy

from IPython.display import clear_output

from constants import *
from run_model import apply_limits

In [None]:
# replace argparse args in notebook

@dataclass
class Args():
    """
    Notebook stand-in for the argparse namespace that load_data expects.

    Every setting carries a type annotation so that @dataclass registers it as a
    real field: it then appears in repr()/asdict() and can be overridden per
    instance, e.g. Args(wave=2). Calling Args() with no arguments still yields
    the defaults below.
    """
    # directory that the wave sub-folders and cached csv files are read from
    data_dir: str = '/datasets/evidationdata'
    split_seed: int = 0
    # if True, load_data reindexes every participant onto a gap-free daily date range
    regularly_sampled: bool = True
    # if False, load_data keeps only survey rows of participants that also have activity data
    all_survey: bool = False
    min_date: 'str | None' = '2019-10-01'  # None or format 'yyyy-mm-dd'
    max_date: 'str | None' = '2020-08-14'  # None or format 'yyyy-mm-dd'
    # This argument is required by the GET_PATH_DICT_KEY function. For the dataloader it should
    # always be False; the fake data is made in a separate file given the correct base data.
    fake_data: bool = False
    # If from_src, the data is loaded from the original files from Raghu; if not from_src,
    # it is reloaded from the final csv's if they exist.
    from_src: bool = True
    # wave is an integer key into the wave_dict dictionary inside load_data, which holds the
    # source file paths for that wave (entries 1, 2 and 3 are defined there).
    wave: int = 3
    

args=Args()


In [None]:
def merge_helper(sub_group):
    """
    Attach the overall onset/recovery window of a group of survey rows to every row.
    inputs:
        sub_group (pd.DataFrame): rows with 'date_onset_merged' and 'date_recovery_merged' columns
    returns:
        pd.DataFrame: a new dataframe with the rows of sub_group plus two columns,
        'new_date_onset_merged' (earliest onset in the group) and
        'new_date_recovery_merged' (latest recovery in the group)
    """
    earliest_onset = pd.to_datetime(sub_group['date_onset_merged'].min())
    latest_recovery = pd.to_datetime(sub_group['date_recovery_merged'].max())

    # assign() builds a new frame, so the group object handed over by
    # groupby.apply is left untouched instead of being written into.
    return sub_group.assign(new_date_onset_merged=earliest_onset,
                            new_date_recovery_merged=latest_recovery)
    
def merge_close(participant):
    """
    Merge the illness events of one participant that lie close together.
    A row starts a new event when its onset is more than 7 days after the recovery
    date of the row before it; all other rows join the running event. Each event
    is then passed through merge_helper.
    inputs:
        participant (pd.DataFrame): the rows of a single participant
    returns:
        the result of applying merge_helper to every event group
    """
    gap_since_previous_recovery = participant['date_onset_merged'] - participant['date_recovery_merged'].shift()
    starts_new_event = gap_since_previous_recovery > pd.to_timedelta(7, unit='days')
    # the running count of "new event" flags labels each row with its event number
    event_number = starts_new_event.cumsum()
    merged = participant.groupby(event_number).apply(merge_helper)

    # progress report: print the call counter every 1000 participants
    merge_close.num_calls += 1
    if merge_close.num_calls % 1000 == 0:
        print(merge_close.num_calls)

    return merged
merge_close.num_calls = 0

def load_data(args):
    """
    Load the activity, baseline and survey data from the data dir and build the label columns.
    inputs:
        args: settings object (e.g. the notebook's Args instance) providing data_dir, wave,
            from_src, min_date, max_date, all_survey and regularly_sampled
    returns:
        (dict): a dictionary of pandas.DataFrames with the keys 'activity', 'baseline' and 'survey'
    """
    # paths (relative to args.data_dir) of the source files of each wave
    # NOTE(review): wave 3 reads its numeric features from the wave2 folder -- confirm this is intended.
    wave_dict = {1: {'numeric': 'wave1/fbml_day_level_features.feather', 
                     'baseline': 'wave1/baseline_health_data.csv',
                     'survey': 'wave1/survey_data.csv'},
                 2: {'numeric': 'wave2/day_level_activity_features.feather', 
                     'baseline': 'wave2/demographic_health_data.csv',
                     'survey': 'wave2/survey_data.csv'},
                 3: {'numeric': 'wave2/day_level_activity_features.feather', 
                     'baseline': 'wave3/demographics.csv',
                     'survey': 'wave3/survey.csv'},}
    
    dfs={}
    
    #numeric features
    if os.path.exists(os.path.join(args.data_dir, 'activity_data.csv')) and not args.from_src:
        out_df = pd.read_csv(os.path.join(args.data_dir, 'activity_data.csv'))
    else:
        out_df = pd.read_feather(os.path.join(args.data_dir, wave_dict[args.wave]['numeric']))
        out_df['date'] = pd.to_datetime(out_df['date'])
        print(out_df.columns.tolist())
        if args.wave == 1:
            # wave 1 has a second file for 2020-04-01 to 2020-04-17; stack it onto the first one
            out_df=out_df.loc[out_df['date']<= '2020-04-01']
            out_df2 = pd.read_feather(os.path.join(args.data_dir, 'wave1/fbml_day_level_features-2020-04-01_to_2020-04-17.feather'))
            out_df2['date'] = pd.to_datetime(out_df2['date'])
            print(out_df2.columns.tolist())
            # DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent call
            out_df=pd.concat([out_df, out_df2], ignore_index=True).drop_duplicates(keep='last')
    # drop the first sleep start-time column; skip quietly when the loaded file has none
    start_time_cols = [s for s in out_df.columns.to_list() if 'sleep__main_start_time' in s]
    if start_time_cols:
        out_df = out_df.drop(start_time_cols[0], axis = 1)
    out_df['date'] = pd.to_datetime(out_df['date'])
    if args.min_date is not None:
        out_df=out_df.loc[out_df['date']>=args.min_date]
    if args.max_date is not None:
        out_df=out_df.loc[out_df['date']<= args.max_date]
    out_df.set_index(['participant_id', 'date'], inplace=True)
    # collapse repeated (participant_id, date) rows by taking the column-wise maximum
    out_df=out_df.groupby(['participant_id', 'date']).max()
    # if all data in a row is 0, other than the two sleep features, then we should drop those data points.
    zero_default_cols= ['sleep__sleep__total_asleep_minutes', 'sleep__sleep__total_in_bed_minutes']
    # comment this out because we fixed the consent ranges for participants.
#     print('before drop: ', len(out_df))
    # a 0 in the two sleep columns is treated as missing
    out_df[zero_default_cols] = out_df[zero_default_cols].replace({0:np.nan})
#     out_df_no_missingness_index = out_df[QC_VARIABLES].dropna(how='all').index
#     out_df = out_df.loc[out_df_no_missingness_index, :]
#     print('after drop: ', len(out_df))
    
    display(out_df.head())
    participant_set=out_df.index.get_level_values('participant_id')
    dfs['activity'] = out_df.copy()
    print('Done numeric features')
    
    #static features
    if os.path.exists(os.path.join(args.data_dir, 'baseline_health_data.csv')) and not args.from_src:
        out_df = pd.read_csv(os.path.join(args.data_dir, 'baseline_health_data.csv'), index_col='participant_id')
    else: 
        out_df = pd.read_csv(os.path.join(args.data_dir, wave_dict[args.wave]['baseline']))
        # the participant_id column is called user_id in wave2, change it for consistency.
        out_df.columns = out_df.columns.str.replace('user_id', 'participant_id')
        out_df = out_df.set_index('participant_id')
    # keep only the participants that also have activity data
    out_df=out_df.loc[list(set(participant_set).intersection(set(out_df.index.tolist()))),:]
    out_df = one_hot_encode_cols(out_df, BASELINE_CATEGORICAL_COLUMNS)        
    out_df[BASELINE_ZEROFILL] = out_df[BASELINE_ZEROFILL].fillna(0)
    out_df[BASELINE_MEDIANFILL] = out_df[BASELINE_MEDIANFILL].fillna(out_df[BASELINE_MEDIANFILL].median())
    out_df[BASELINE_COLS] = out_df[BASELINE_COLS].apply(pd.to_numeric)
    dfs['baseline']=out_df[BASELINE_COLS+['state']].copy()
    print('Done static features')
    
    
    # labels_dataframe
    if os.path.exists(os.path.join(args.data_dir, 'survey_data.csv')) and not args.from_src:
        out_df = pd.read_csv(os.path.join(args.data_dir, 'survey_data.csv'))
    else:
        out_df = pd.read_csv(os.path.join(args.data_dir, wave_dict[args.wave]['survey']))
    
    out_df['date_survey'] = pd.to_datetime(out_df['date_survey'])
    out_df['date_onset_merged'] = pd.to_datetime(out_df['date_onset_merged'])
    out_df['date_recovery_merged'] = pd.to_datetime(out_df['date_recovery_merged'])
    
    out_df.set_index(['participant_id', 'date_survey'], inplace=True)
    
#     out_df=out_df.iloc[:50000] # set this for debug
    
    # Keep only survey data for participants with activity data as well.
    if not(args.all_survey):
        merged_index=list(set(participant_set).intersection(set(out_df.index.get_level_values('participant_id'))))
        out_df=out_df.loc[out_df.index.get_level_values('participant_id').isin(merged_index),:]
    
    # Merge events where the previous date of recovery is only 7 days away from the next date of onset
    out_df.reset_index(inplace=True)
    out_df = out_df.groupby('participant_id').apply(merge_close) 
    out_df['date_onset_merged']=out_df['new_date_onset_merged'].values
    out_df['date_recovery_merged']=out_df['new_date_recovery_merged'].values
    
    out_df.set_index(['participant_id', 'date_survey'], inplace=True)
    
    # create test set participants from the set of participants in the survey, this won't remove them yet, 
    # just not calculate their event dates.
#     test_participants = out_df.groupby('participant_id').agg({'date_onset_merged': 'min'})
#     test_participants = np.unique(test_participants[test_participants['date_onset_merged'] > '2020-06-29'].index.get_level_values('participant_id'))
#     print('Number of participants in the test set: ' + str(len(test_participants)))
#     train_participants = [p for p in np.unique(out_df.index.get_level_values('participant_id')) if p not in test_participants]
#     out_df = out_df.loc[train_participants]
#     dfs['train_participants'] = train_participants
#     dfs['test_participants'] = test_participants
    out_df.fillna(0, inplace=True)
    
    dfs['survey']= prepare_labels_df(out_df)
    
    # setting a min_date also helps to remove rows which are introduced because of far earlier health problems
    if args.min_date is not None:
        dfs['survey']=dfs['survey'].loc[dfs['survey'].index.get_level_values('date_survey')>=args.min_date]
    if args.max_date is not None:
        dfs['survey']=dfs['survey'].loc[dfs['survey'].index.get_level_values('date_survey')<=args.max_date]
    print('Done labels')
    
#     print(set(dfs['survey'].index.tolist()))
#     print(set(dfs['activity'].index.tolist()))

    print('before', len(dfs['survey'].index), len(dfs['activity'].index), )
    
    #join the indices of everything
    if not(args.all_survey):
        # both frames end up on the union of the survey and activity indices
        dfs['survey'] = dfs['survey'].reindex(dfs['survey'].index.union(dfs['activity'].index).drop_duplicates())
        dfs['survey'].index.names = ['participant_id', 'date']
        dfs['activity'] = dfs['activity'].reindex(dfs['activity'].index.union(dfs['survey'].index).drop_duplicates())
        dfs['activity'].index.names = ['participant_id', 'date']
    else:
        # the survey frame is left as it is; the activity frame gains the survey rows of
        # those participants that already appear in the activity data
        dfs['survey'].index.names = ['participant_id', 'date']
        survey_activity_intersection = dfs['survey'].loc[dfs['survey'].index.get_level_values('participant_id').isin(set(dfs['activity'].index.get_level_values('participant_id')))].index
        dfs['activity'] = dfs['activity'].reindex(dfs['activity'].index.union(survey_activity_intersection).drop_duplicates())
        dfs['activity'].index.names = ['participant_id', 'date']
    
    print('after', len(dfs['survey'].index), len(dfs['activity'].index), )
    
    # Remove the participants from the activity and survey data so that only the training set of participants is in the data.
#     dfs['survey'] = dfs['survey'].loc[train_participants]
#     dfs['activity'] = dfs['activity'].loc[train_participants]
    
#     print(set(dfs['survey'].index.tolist()))
#     print(set(dfs['activity'].index.tolist()))
    
    if args.regularly_sampled:
        # for models that don't handle missingness.
        # first and last date of every participant in the activity data
        min_inds= dfs['activity'].reset_index('date').groupby('participant_id').min().set_index(['date'], append=True).index
        max_inds= dfs['activity'].reset_index('date').groupby('participant_id').max().set_index(['date'], append=True).index

        person, min_i = zip(*sorted(min_inds.tolist()))
        person, max_i = zip(*sorted(max_inds.tolist()))
        all_inds=zip(person, min_i, max_i)

        # one row per participant for every calendar day between their first and last date
        result_inds=[]
        for ind in all_inds:
            dates = pd.date_range(start=ind[1], end=ind[2])
            result_inds+=list(zip([ind[0]]* len(dates), list(dates)))
        result_inds = pd.MultiIndex.from_tuples(result_inds, names = ['participant_id', 'date'])        
        dfs['survey'] = dfs['survey'].reindex(result_inds)
        dfs['activity'] = dfs['activity'].reindex(result_inds)
    
    # 1 for Monday to Friday, 0 for Saturday and Sunday
    dfs['activity']['weekday']=(dfs['activity'].index.get_level_values('date').weekday<5).astype(np.int32)

    print('Done Participant Join')
    
    # now the labels must be filled in
    
    #get the index of the first date reported in the dataset for each participant
    tmp = dfs['survey'].reset_index()[['participant_id', 'date']].groupby('participant_id').min()
    tmp_index = tmp.set_index(['date'], append=True, drop=False).index
    
    #columns that are normally 0  
    zero_start_cols=[col for col in dfs['survey'].columns if col not in EXCLUDE_ZERO_FILL_LABEL_COLS]
    dfs['survey'].loc[tmp_index, zero_start_cols] = dfs['survey'].loc[tmp_index, zero_start_cols].fillna(0)
    
    # columns that are normally 1 (covid_no_sym)
    one_start_cols = ['covid__symptoms__none', 'symptoms__no_symptoms', 'covid__behavior__social_distancing__did_not']
    dfs['survey'].loc[tmp_index, one_start_cols] = dfs['survey'].loc[tmp_index, one_start_cols].fillna(1)
    
    # this tmp index is just for plotting df: the two earliest dates and the latest date of each participant
    tmp = dfs['survey'].reset_index()[['participant_id', 'date']].groupby('participant_id').apply(lambda x:x.nsmallest(2, 'date')).reset_index(drop=True)
    tmp_index = tmp.set_index(['participant_id', 'date'], append=False, drop=False).index
    tmp = dfs['survey'].reset_index()[['participant_id', 'date']].groupby('participant_id').max()
    tmp_index = tmp_index.union(tmp.set_index(['date'], append=True, drop=False).index)
    
    display(dfs['survey'].loc[tmp_index, :])
    
    print("forward filling...")
    # impute the zero start and one-start cols by forward filling; the first row of every
    # participant was given a start value above
#     dfs['survey'].loc[:, zero_start_cols+one_start_cols] = dfs['survey'].loc[:, zero_start_cols+one_start_cols].bfill()
    dfs['survey'].loc[:, zero_start_cols+one_start_cols] = dfs['survey'].loc[:, zero_start_cols+one_start_cols].ffill()
    display(dfs['survey'].loc[tmp_index, :])
    
    if args.wave==2:
        print('Enforcing limits')
        limit_df = pd.read_csv(os.path.join(args.data_dir, 'wave2', 'activity_start_end_date.csv'))
        dfs = apply_limits(dfs, limit_df) # apply limits is from run_model
    
    # backward filling won't fill the zeros for the healthy days of the last participant
#     dfs['survey'][['ili', 'ili_24', 'ili_48', 'covid', 'covid_24', 'covid_48']] = dfs['survey'][['ili', 'ili_24', 'ili_48', 'covid', 'covid_24', 'covid_48']].fillna(0) 
#     display(dfs['survey'].loc[tmp_index, :])
    
    print("Done forward filling!")
    
    # Make a three class label, 0=healthy, 1=flu 2=covid
    # Commented out lines was a check for if there are overlapping medical__diagnosed and covid.
#     flu_and_covid = dfs['survey']['medical__diagnosed'] + dfs['survey']['covid']
#     assert len(flu_and_covid[flu_and_covid == 2]) == 0, 'found overlap of flu diagnosis and covid diagnosis labels'
    # Sum the two together giving 1 if medical__diagnosed ili 2 if covid, 3 if both.
    # (the covid term used to be added twice; after the cap at 2 below the result is the same)
    dfs['survey']['flu_covid'] = ((dfs['survey']['medical__diagnosed'] == 1)&(dfs['survey']['ili'] == 1)) + 2*dfs['survey']['covid']
    # Since there is overlap between medical__diagnosed and covid, if the sum is greater than two set it to two 
    # indicating covid.
    # NOTE(review): the row-wise min skips NaN, so a row whose covid value is still NaN ends up
    # as 2 here -- confirm this is intended.
    tmp = dfs['survey'][['flu_covid']].copy()  # explicit copy: do not write into a slice of the survey frame
    tmp['two'] = 2
    dfs['survey']['flu_covid'] = tmp.min(axis=1)
        
    return dfs




def prepare_labels_df(input_df):
    """
    Convert *d_ago into timeseries
    Convert complication onset into timeseries
    Build the ili / covid label columns from the merged onset and recovery dates
    inputs:
        input_df (pd.DataFrame): survey answers indexed by (participant_id, date_survey)
    returns:
        pd.DataFrame: the LABEL_COLS columns of the expanded label dataframe
    """
    label_df = input_df.copy()

#     label_df.columns = [col if '0d_ago' not in col else col.replace('__0d_ago', '') for col in label_df.columns]
    # make sure every *__0d_ago column also exists under its plain name
    for col in label_df.columns:
        if '0d_ago' in col:
            print(col)
            if col.replace('__0d_ago', '') not in label_df.columns:
                label_df[col.replace('__0d_ago', '')] = label_df[col]
    
    # rows that come straight from a survey are 0 days away from it
    label_df['time_to_survey']=0
    
    label_df = label_df.loc[:, ~label_df.columns.duplicated()]
    
    # reconcile multilabel targets
    
    # create covid_cohort category.
    idx=pd.IndexSlice
    label_df.loc[:, 'covid_cohort']=0
    label_df['covid__diagnosed'] = label_df['covid__diagnosed'].str.lower()
    covid_tested = set(label_df.loc[label_df['covid__diagnosed'].isin(['yes', 'no', 'i am waiting for my diagnosis']), :].index.get_level_values('participant_id'))
#     print('yes: ', len(set(label_df.loc[label_df['covid__diagnosed'].isin(['yes']), :].index.get_level_values('participant_id'))))
#     print('no: ', len(set(label_df.loc[label_df['covid__diagnosed'].isin(['no']), :].index.get_level_values('participant_id'))))
#     print('waiting: ',len(set(label_df.loc[label_df['covid__diagnosed'].isin(['i am waiting for my diagnosis']), :].index.get_level_values('participant_id'))))

    label_df.loc[idx[covid_tested, :], 'covid_cohort']=1    
    
    #assert len(covid_tested)<=177+1324+192 # from the survey data as of May 13
    #pd.get_dummies one hot encode covid_diagnosed with no nan col
    # the join puts the original covid__diagnosed column back next to its dummies
    label_df=pd.get_dummies(label_df, columns=['covid__diagnosed']
                           ).join(label_df['covid__diagnosed'])
    label_df['covid__diagnosed'] = label_df['covid__diagnosed'].replace({'yes':1, 'no': 0, 'i am waiting for my diagnosis':np.nan,  'i don t know i can t remember':np.nan})
    # any answer that is still a string becomes missing
    label_df['covid__diagnosed'] = label_df['covid__diagnosed'].apply(lambda x: np.nan if isinstance(x, str) else x)
    label_df['covid_diagnosed_date'] = label_df['covid__diagnosed'].fillna(np.nan)
    
    # medical__diagnosed: 'yes' -> 1, every other answer -> 0
    med_diag = label_df['medical__diagnosed'].astype(str).str.lower()
    med_responses = np.unique(med_diag)
    replacement_dict = {}
    for r in med_responses:
        replacement_dict[r] = 1 if r == 'yes' else 0
    label_df['medical__diagnosed'] = med_diag.replace(replacement_dict)
#     label_df.loc[~label_df['medical__diagnosed'].isna(), 'medical__diagnosed']=(label_df.loc[~label_df['medical__diagnosed'].isna(), 'medical__diagnosed']=='yes').astype(np.int32)
    
        
        
    # clean up the categories that don't binarise nicely
    label_df['home__household_members_had_flu'] = pd.to_numeric(label_df['home__household_members_had_flu'].str.lower().replace({'yes':1}), errors='coerce').fillna(0)
    label_df['home__household_members_n'] = pd.to_numeric(label_df['home__household_members_n'].str.lower().replace({'>10':11}), errors='coerce').fillna(0)
    label_df['medical__hospitalized'] = pd.to_numeric(label_df['medical__hospitalized'].str.lower().replace({'yes':1, 'no':0}), errors='coerce').fillna(0)
    label_df['medical__medication'] = pd.to_numeric(label_df['medical__medication'].str.lower().replace({'yes':1, 'no':0}), errors='coerce').fillna(0)
    label_df['medical__medication_otc'] = pd.to_numeric(label_df['medical__medication_otc'].str.lower().replace({'yes':1, 'no':0}), errors='coerce').fillna(0)
    label_df['medical__sought_attention'] = pd.to_numeric(label_df['medical__sought_attention'].str.lower().replace({'yes':1, 'no':0}), errors='coerce').fillna(0)
    label_df['covid__quarantine'] = pd.to_numeric(label_df['covid__quarantine'].str.lower().replace({'yes':1, 'no':0}), errors='coerce').fillna(0)
    label_df['medical__vaccinated_last_year'] = pd.to_numeric(label_df['medical__vaccinated_last_year'].str.lower().replace({'yes':1, 'no':0}), errors='coerce').fillna(0)
    label_df['medical__vaccinated_this_year'] = pd.to_numeric(label_df['medical__vaccinated_this_year'].str.lower().replace({'yes':1, 'no':0}), errors='coerce').fillna(0)


    
    
    # replace strings with binary indicators (1 where a value is present, 0 where it is null)
    print("BEFORE:")
    print(label_df.groupby('covid__any_household_diagnosed').count())
    BINARY_COLS2=[col for col in BINARY_COLS if col!='covid__diagnosed']
    label_df[BINARY_COLS2] = label_df[BINARY_COLS2].notnull().astype('int')    
    print("AFTER:")
    print(label_df.groupby('covid__any_household_diagnosed').count())
    
    
    # expand index
    all_inds = label_df.index.tolist()
    for offset in range(0, 7):
        # fill in dates for the missing days of the week
        additional_index = [(item[0], item[1]-pd.to_timedelta(offset, unit='d')) for item in  set(input_df.index.tolist())]
        all_inds += additional_index
        
        new_index = pd.MultiIndex.from_tuples(list(set(all_inds)), names=label_df.index.names).drop_duplicates()
        label_df = label_df.reindex(new_index)
        label_df['time_to_survey'] = label_df['time_to_survey'].fillna(offset) # add the time to survey here for reliability's sake

        label_df = label_df.sort_index()
        
        # bfill columns that have {offset}d_ago in them
        bfill_columns = [col for col in label_df.columns if f"__{offset}d_ago" in col] # this is why I go one row at a time
        
        # .bfill(limit=1) replaces the deprecated fillna(method='bfill', limit=1)
        label_df[bfill_columns]=label_df[bfill_columns].bfill(limit=1)


        # fill data with a dict of series using the *d_ago values
        replace_vals={col.replace(f"__{offset}d_ago", ""):label_df[col] for col in label_df.columns if f"__{offset}d_ago" in col}
        label_df = label_df.fillna(replace_vals)
        
    
        
    # now convert the complication columns to timeseries by adding them to the index and one-hot encoding
    
    for col in label_df.columns:
        if '__onset' not in col:
            continue
        label_df[col]= pd.to_datetime(label_df[col], errors='coerce')
        # (participant_id, onset date) pairs of every row that reports an onset
        tups=list(zip(label_df.loc[~label_df[col].isna(), :].index.get_level_values('participant_id'),label_df.loc[~label_df[col].isna(), col].values))
        
        # join these new dates into the index.
        new_index = pd.MultiIndex.from_tuples(list(set(label_df.index.tolist()+tups)), names=label_df.index.names).drop_duplicates()
        label_df = label_df.reindex(new_index)
        
        # indicator column named after the complication: 1 on the onset date, 0 elsewhere
        label_df.loc[:, col.replace('__onset', '')]=0
        label_df.loc[tups, col.replace('__onset', '')]=1
        
        
    # now convert the date columns to timeseries by adding them to the index and one-hot encoding
    """
    ILI Labels
    we will follow the following procedure:
    1) Get the day after date_recovery_merged
    2) Set these values to 0 for ILI
    3) Get the index of date_onset_merged
    4) Set these values to 1 for ILI
    """
    label_df.loc[:, 'ili']=np.nan
    label_df.loc[:, 'ili_24']=np.nan
    label_df.loc[:, 'ili_48']=np.nan
    
    
    # **************************
    # 1)
    
    col = 'date_recovery_merged'
    # note: from here on date_recovery_merged holds the day AFTER the reported recovery
    label_df[col]= pd.to_datetime(label_df[col], errors='coerce')+pd.to_timedelta(1, unit='D')
    tups=list(zip(label_df.loc[~label_df[col].isna(), :].index.get_level_values('participant_id'),label_df.loc[~label_df[col].isna(), col].values))
    col_mod='date_onset_merged' # this will be 0 for ili_24 and ili_48
    tups_24=list(zip(label_df.loc[~label_df[col].isna(), :].index.get_level_values('participant_id'),label_df.loc[~label_df[col].isna(), col_mod].values))
    tups_48=list(zip(label_df.loc[~label_df[col].isna(), :].index.get_level_values('participant_id'),label_df.loc[~label_df[col].isna(), col_mod].values))

    # join these new dates into the index.
    new_index = pd.MultiIndex.from_tuples(list(set(label_df.index.tolist()+tups+tups_24+tups_48)), names=label_df.index.names).drop_duplicates()
    label_df = label_df.reindex(new_index)
    
    # **************************
    # 2) the end of each label is being set to zero

    label_df.loc[tups, 'ili']=0 if 'recovery' in col else 1
    label_df.loc[tups_24, 'ili_24']=0 if 'recovery' in col else 1
    label_df.loc[tups_48, 'ili_48']=0 if 'recovery' in col else 1
    
    # **************************
    # 3)
    
    col = 'date_onset_merged'
    label_df[col]= pd.to_datetime(label_df[col], errors='coerce')
    # ili starts on the onset date, ili_24 one day before it and ili_48 two days before it
    tups=list(zip(label_df.loc[~label_df[col].isna(), :].index.get_level_values('participant_id'),label_df.loc[~label_df[col].isna(), col].values))
    tups_24=list(zip(label_df.loc[~label_df[col].isna(), :].index.get_level_values('participant_id'),(label_df.loc[~label_df[col].isna(), col]-pd.to_timedelta(1, unit='D')).values))
    tups_48=list(zip(label_df.loc[~label_df[col].isna(), :].index.get_level_values('participant_id'),(label_df.loc[~label_df[col].isna(), col]-pd.to_timedelta(2, unit='D')).values))

    new_index = pd.MultiIndex.from_tuples(list(set(label_df.index.tolist()+tups+tups_24+tups_48)), names=label_df.index.names).drop_duplicates()
    label_df = label_df.reindex(new_index)
    
    # **************************
    # 4)
    # set start of 1 labels        
    label_df.loc[tups, 'ili']=0 if 'recovery' in col else 1
    label_df.loc[tups_24, 'ili_24']=0 if 'recovery' in col else 1
    label_df.loc[tups_48, 'ili_48']=0 if 'recovery' in col else 1
    
    
    
    """
    COVID Labels
    we will follow the following procedure:
    1) Get the day after date_recovery_merged
    2) Set these values to 0 for covid__diagnosed
    3) Get the index of date_onset_merged
    4) Set these values to 1 for covid__diagnosed
    5) Make the default starting value for covid__diagnosed 0
    """
        
    label_df['covid_range']=np.nan
    
    # **************************
    # 1)    
    # recall we shifted the recovery date in the ili labels
    covid_recovery_tups = list(zip(label_df.loc[(~label_df['date_recovery_merged'].isna())&(label_df['covid__diagnosed']==1), :].index.get_level_values('participant_id'),
                                   label_df.loc[(~label_df['date_recovery_merged'].isna())&(label_df['covid__diagnosed']==1), 'date_recovery_merged'].values))
    # **************************
    # 2)
    label_df.loc[covid_recovery_tups, 'covid_range']=0 # we don't need to reindex because all date recoveries are already done
    # **************************
    # 3)
    # the covid range opens two days before the onset date
    covid_onset_tups = list(zip(label_df.loc[(~label_df['date_onset_merged'].isna())&(label_df['covid__diagnosed']==1), :].index.get_level_values('participant_id'),
                                (label_df.loc[(~label_df['date_onset_merged'].isna())&(label_df['covid__diagnosed']==1), 'date_onset_merged']-pd.to_timedelta(2, unit='D')).values))
    # **************************
    # 4)
    label_df.loc[covid_onset_tups, 'covid_range']=1 # we don't need to reindex because all date recoveries are already done
    # **************************
    # 5)
    # set the min index for each participant equal to 0
    label_df = label_df.sort_index()
    tmp = label_df.reset_index()[['participant_id', 'date_survey']].groupby('participant_id').min()
    tmp_index = tmp.set_index(['date_survey'], append=True, drop=False).index
    label_df.loc[tmp_index, 'covid_range'] = label_df.loc[tmp_index, 'covid_range'].fillna(0)
    
    # **************************
    # 6) groupby and forward fill 
    label_df['covid_range'] = label_df['covid_range'].ffill()
    # step 5 is unnecessary if we just did the following (but groupbys are slow)
    label_df['covid_range'] = label_df['covid_range'].groupby('participant_id').ffill() # this could potentially be redundant.
    label_df['covid_range']=label_df['covid_range'].fillna(0)
    
    # **************************
    # 7)
    # combine the covid_range column with the ili columns to get the covid labels

    label_df['covid']=np.nan
    label_df['covid_24']=np.nan
    label_df['covid_48']=np.nan
        
    # match up with ILI
    label_df.loc[(label_df['ili']==0), 'covid']=0 # this explicitly excludes participants without symptoms. (they do not exist in the dataset)
    label_df.loc[(label_df['ili_24']==0), 'covid_24']=0
    label_df.loc[(label_df['ili_48']==0), 'covid_48']=0
    
    label_df.loc[(label_df['covid_range']==1)&(label_df['ili']==1), 'covid']=1 # this explicitly excludes participants without symptoms. (they do not exist in the dataset)
    label_df.loc[(label_df['covid_range']==1)&(label_df['ili_24']==1), 'covid_24']=1
    label_df.loc[(label_df['covid_range']==1)&(label_df['ili_48']==1), 'covid_48']=1
    
    
    
    # Now handle the categorical variables:
    for feat in ['covid__behavior__air_travel', 'covid__contact_ILI_outside_household', 'covid__contact_covid']:
        # get additional indices
        print(set(input_df[feat].values), input_df[feat].fillna('no').astype(str).str.contains('7 days').sum())
        
        index_for_feature = input_df.loc[input_df[feat].fillna('no').astype(str).str.contains('7 days')].index.tolist()
        
        # Copy the list: with a plain alias every "+=" below also grew index_for_feature, so each
        # offset was applied to the dates added by the earlier offsets and the window
        # stretched far beyond the 7 days it is meant to cover.
        result_index = list(index_for_feature)
        # for last_7 days
        for offset in range(1, 7):
            result_index += [(item[0], item[1]-pd.to_timedelta(offset, unit='d')) for item in set(index_for_feature)]
        
        index_for_feature = input_df.loc[input_df[feat].fillna('no').astype(str).str.contains('last 14 days')].index.tolist()
        # for last 14 days
        for offset in range(7, 14):
            result_index += [(item[0], item[1]-pd.to_timedelta(offset, unit='d')) for item in set(index_for_feature)]
        
        all_inds = deepcopy(label_df.index)
        print(feat, len(set(result_index)-set(all_inds.tolist())))
        
        new_index = pd.MultiIndex.from_tuples(list(set(result_index)), names=label_df.index.names).drop_duplicates()
        
        
        label_df[feat]=0 # first set everything to 0
        label_df.loc[all_inds.intersection(new_index), feat]=1 # now set the intersecting labels to 1
    


    for col in SURVEY_ROLLING_FEATURES:
        if col in ['covid__behavior__air_travel', 'covid__contact_ILI_outside_household', 'covid__contact_covid']:
            continue
#         print()
#         print(col)
#         print(label_df[col].dtype)
        if label_df[col].dtype!=object:
            continue
        # only columns whose answers are exactly yes / no / missing are converted to 1 / 0
        lowered = label_df[col].str.lower()
        if set(lowered.values.tolist())==set([np.nan, 'no', 'yes']):
            label_df.loc[lowered=='no', col]=0
            label_df.loc[lowered=='yes', col]=1
            if col not in ['covid__symptoms__none', 'symptoms__no_symptoms', 'covid__behavior__social_distancing__did_not']:
                label_df[col]=label_df[col].fillna(0) # this assumes no contact

#         all_inds+= result_index
#         new_index = pd.MultiIndex.from_tuples(list(set(all_inds)), names=label_df.index.names).drop_duplicates()
#         label_df=label_df.reindex(new_index)
        
#         label_df[feat]=0
#         label_df.loc[result_index, feat]=1

    
    return label_df[LABEL_COLS]
    
    


    


def one_hot_encode_cols(in_df, cols):
    """
    Swap each listed categorical column for its indicator (dummy) columns.
    inputs:
        in_df (pd.DataFrame): a dataframe with categorical features as strings
        cols (list): the columns to one-hot-encode
    returns:
        pd.DataFrame: a copy of in_df in which every column of cols is removed and its
        indicator columns, named '<column>_<value>', are appended on the right
    """
    encoded = in_df.copy()
    for name in cols:
        indicators = pd.get_dummies(encoded[name], prefix=name)
        # take out the source column and append its indicators in one step
        encoded = pd.concat([encoded.drop(name, axis=1), indicators], axis=1)
    return encoded
    

# build all dataframes for the configured run
dfs = load_data(args)

# TODO: the loaded data must pass the following checks
if not args.all_survey:
    assert dfs['survey'].index.equals(dfs['activity'].index), "The indices are not the same"

# check that all of our necessary outcomes are loaded
for required_outcome in ("ili", "ili_24", "ili_48", "covid", "covid_24", "covid_48"):
    assert required_outcome in dfs['survey'].columns

# count the participants whose label reaches 1 at least once (per-participant max, then summed)
ili_participant_count = dfs['survey']['ili'].groupby('participant_id').max().sum()
print('number of ili participants: ', ili_participant_count)
covid_participant_count = dfs['survey']['covid'].groupby('participant_id').max().sum()
print('number of covid participants: ', covid_participant_count)

In [None]:
# save dataframe to hdf
# The file name encodes the survey / sampling configuration of this run.
fname = (
    'all_daily_data_'
    + ('allsurvey_' if args.all_survey else '')
    + ('regular' if args.regularly_sampled else 'irregular')
    + '_merged_apr27.hdf'
)

path = os.path.join(args.data_dir, 'test', fname)

# `key` is passed by keyword: recent pandas releases deprecate passing it positionally.
dfs['baseline'].to_hdf(path, key='baseline', format='table')
dfs['activity'].to_hdf(path, key='activity_raw')
dfs['survey'].to_hdf(path, key='survey')

# Store the path for the most recent run of this notebook so that each run
# of this will update the path to be used for a given argument configuration.
if os.path.exists(DATA_PATH_DICTIONARY_FILE):
    # NOTE: pickle must only be used on files this project wrote itself.
    with open(DATA_PATH_DICTIONARY_FILE, 'rb') as f:
        path_dict = pickle.load(f)
else:
    path_dict = {}

key = GET_PATH_DICT_KEY(args)

path_dict[key] = path

with open(DATA_PATH_DICTIONARY_FILE, 'wb') as f:
    pickle.dump(path_dict, f)

# pd.Series(dfs['train_participants']).to_hdf(os.path.join(args.data_dir, 'test', fname), 'train_participants')
# pd.Series(dfs['test_participants']).to_hdf(os.path.join(args.data_dir, 'test', fname), 'test_participants')


In [None]:
# load non-imputed dataframe from hdf
# Must match the name used by the save cell (earlier naming schemes were
# dead assignments that this final value overwrote).
fname = (
    'all_daily_data_'
    + ('allsurvey_' if args.all_survey else '')
    + ('regular' if args.regularly_sampled else 'irregular')
    + '_merged_apr27.hdf'
)

# join the location once and reuse it for every table
_raw_hdf_path = os.path.join(args.data_dir, 'test', fname)

dfs = {}
dfs['baseline'] = pd.read_hdf(_raw_hdf_path, key='baseline')
# the un-imputed activity table is stored under 'activity_raw'
dfs['activity'] = pd.read_hdf(_raw_hdf_path, key='activity_raw')
dfs['survey'] = pd.read_hdf(_raw_hdf_path, key='survey')

<h2> Imputation </h2>
<a id="imputation">Simple imputation for GRU-D</a>

In [None]:
# get means from healthy cohort
# rows whose `ili_48` label is 0 define the healthy cohort
inds = dfs['survey'].loc[dfs['survey']['ili_48']==0, :].index.tolist()

_healthy_activity = dfs['activity'].loc[inds, :]

# Seed the draw so the imputation means (and everything saved downstream of
# them) are reproducible between runs, and cap the sample size because
# `sample` raises when asked for more rows than exist (without replacement).
train_means = _healthy_activity.sample(
    n=min(5000, len(_healthy_activity)),
    random_state=args.split_seed,
).mean()

# the healthy subset can be large; release it once the means are computed
del _healthy_activity





def simple_impute(df_input, train_means, ID_COLS=['participant_id'], tqdm=tqdm):
    """
    Build GRU-D style inputs (values, observation mask, time gaps) from a feature frame.

    Args:
        df_input (pd.DataFrame): features indexed by a MultiIndex that contains
            the levels 'participant_id' and 'date'; NaN marks a missing value.
            Rows are assumed to be sorted by date within each participant.
        train_means (pd.Series): per-column values used to fill a participant's
            first row where it is missing.
        ID_COLS (list): index level(s) identifying one time series; rows are
            grouped by these when accumulating the time since last observation.
        tqdm: progress-bar factory; an instance is registered with pandas so
            that `progress_apply` is available.

    Returns:
        tuple of six frames:
            df_in: untouched copy of `df_input`
            norm_df: values z-scored against a lagged rolling window, outliers
                blanked, forward filled per participant (NaN where undefined)
            ffilled_norm_df: `norm_df` with the remaining NaN replaced by 0
            X_last_obsv_df: raw values with the last observation carried
                forward; each participant's first row is mean-filled beforehand
            masked_df: boolean frame, True where a value was observed
            time_df: days elapsed since each feature was last observed,
                accumulated separately for every participant
    """
    df_in = df_input.copy()

    # masked data: True where a measurement exists
    masked_df = pd.notna(df_in)
    masked_df = masked_df.apply(pd.to_numeric)

    # time since last measurement
    is_absent = (1 - masked_df).apply(pd.to_numeric)

    def cumsum_missing(xs):
        """Return, for one participant, the days since each feature was last observed."""
        # Gaps must come from THIS participant's dates: the row position inside
        # a group is not a valid offset into an array built over the whole frame.
        dates = xs.index.get_level_values('date')
        day_gaps = np.concatenate(([0], np.diff(dates, n=1, axis=0) / np.timedelta64(1, 'D')))

        absent = xs.values
        elapsed = np.zeros(absent.shape, dtype=float)
        for i in range(1, len(absent)):
            # restart the clock where the previous row was observed, otherwise
            # keep accumulating on top of the previous row's elapsed time
            carried = np.where(absent[i - 1] == 0, 0.0, elapsed[i - 1])
            elapsed[i] = day_gaps[i] + carried
        return pd.DataFrame(elapsed, index=xs.index, columns=xs.columns)

    tqdm_pandas(tqdm())
    # group_keys=False keeps the original index; pandas >= 2.0 would otherwise
    # prepend the group key as an additional index level.
    time_df = is_absent.groupby(ID_COLS, group_keys=False).progress_apply(cumsum_missing)

    # last observed value
    X_last_obsv_df = df_input.copy()
    columns = X_last_obsv_df.columns.tolist()

    # index entries of every participant's earliest date
    first_dates = X_last_obsv_df.index.to_frame(index=False).groupby('participant_id')['date'].min()
    inds = first_dates.to_frame().set_index('date', append=True).index

    # mean-impute only those first rows
    # (not sure how original paper did it, possibly just fill with zeros???)
    subset_data = X_last_obsv_df.loc[inds, columns].fillna(train_means)
    X_last_obsv_df.loc[inds, columns] = subset_data.values

    # With every participant's first row filled, a plain forward fill has a
    # value to carry within each participant (given non-NaN train_means).
    X_last_obsv_df = X_last_obsv_df.ffill()
    X_last_obsv_df = X_last_obsv_df.fillna(0)

    # z-score against a rolling window that ends `norm_by_past_days` rows earlier
    num_days_history = 28  # rows in the rolling window
    num_days_minimum = 14  # minimum observed rows for the window to be defined
    norm_by_past_days = 7  # lag applied to the rolling statistics

    # NOTE(review): the rolling statistics are computed over the whole frame,
    # not per participant, so windows near the start of one participant include
    # rows of the previous one — confirm this is intended.
    rolling_mean = df_in.transform(
        lambda x: x.rolling(num_days_history, num_days_minimum).mean().shift(norm_by_past_days))
    rolling_std = df_in.transform(
        lambda x: x.rolling(num_days_history, num_days_minimum).std(ddof=0).shift(norm_by_past_days))
    ffilled_norm_df = (df_in - rolling_mean) / rolling_std

    # six sigma
    # NOTE(review): the z-score is compared with 6 * rolling std rather than
    # with 6 — kept as in the original, confirm the intended threshold.
    ffilled_norm_df = ffilled_norm_df.mask(ffilled_norm_df.abs() > 6 * rolling_std)

    ffilled_norm_df = ffilled_norm_df.groupby('participant_id').ffill()
    norm_df = ffilled_norm_df.copy()
    ffilled_norm_df = ffilled_norm_df.fillna(0)

    display(ffilled_norm_df.head(10))

    return df_in, norm_df, ffilled_norm_df, X_last_obsv_df, masked_df, time_df

print('imputing')
inp, norm_df, ffilled_norm_df, ffilled_df, mask_df, time_df = simple_impute(dfs['activity'], train_means)

# Every block that is reordered below follows the row order of the forward-filled frame.
_row_order = ffilled_df.index

# Stack the six views side by side under a two-level column index
# (outer level: which view, inner level: the feature).
_views = {
    'measurement_z': ffilled_norm_df.loc[_row_order, :],
    'measurement': ffilled_df.loc[_row_order, :],
    'measurement_noimp': inp,
    'measurement_z_noimp': norm_df,
    'mask': mask_df.loc[_row_order, :],
    'time': time_df.loc[_row_order, :],
}
df2 = pd.concat(_views, axis=1, names=['df_type', 'value'])

display(df2)

# written into the same hdf file as the raw tables, under the 'activity' key
df2.to_hdf(os.path.join(args.data_dir, 'test', fname), key='activity')

print('done saving')


# Add time_to_onset to survey dataframe #

<a id="time_to_onset"> </a>

In [None]:
# Reload the survey table that time_to_onset will be added to.
# NOTE(review): this file name has no '_merged_apr27' suffix, unlike the name
# built by the save/load cells earlier in this notebook — confirm which file is meant.
_survey_scope = 'allsurvey_' if args.all_survey else ''
_sampling = 'regular' if args.regularly_sampled else 'irregular'
fname = 'all_daily_data_' + _survey_scope + _sampling + '.hdf'

# The survey dataframe is the only input of the calculation.
dfs = {}
dfs['survey'] = pd.read_hdf(os.path.join(args.data_dir, 'test', fname), key='survey')

In [None]:
# Calculate time_to_onset, add the column to the survey dataframe, save the new survey dataframe to the hdf file read from.
def calculate_tto(group_df):
    """
    Add a `time_to_onset` column (in days) to one participant's survey rows.

    An onset is a date where `ili` steps from 0 to 1 between consecutive rows
    (the first row counts as an onset when the series starts at 1). Rows with
    `ili == 1` receive the days since the latest onset on or before that date;
    rows with `ili == 0` receive the signed distance to the nearest onset,
    negative when the row lies before it. Without any onset the column is NaN.

    group_df: pandas grouped dataframe, grouped by the participant_id; it is
        modified in place and returned.
    """
    assert 'ili' in group_df.columns, "Missing ili column in grouped dataframe to calculate time_to_onset"
    dates = group_df.index.get_level_values('date')
    one_day = pd.Timedelta('1d')

    # Pad the label series with a 0 on either side so that a series which
    # starts (or ends) at 1 still yields a step at its boundary, then drop the
    # two padding entries again.
    level_one = pd.Series(group_df.index.get_level_values(1))
    padded_dates = pd.DatetimeIndex(
        [level_one.min() - one_day] + level_one.to_list() + [level_one.max() + one_day]
    )
    ili_steps = group_df['ili'].droplevel(0).reindex(padded_dates, fill_value=0).diff()[1:-1]

    onset_dates = dates[ili_steps == 1]
    if len(onset_dates) == 0:
        group_df['time_to_onset'] = np.nan
        return group_df

    # one column per onset: signed offset of every row from that onset
    onset_deltas = pd.concat(
        [pd.Series(dates - pd.to_datetime(onset)) for onset in onset_dates],
        axis=1,
    )

    # nearest onset in either direction (used where ili is 0)
    nearest_onset = onset_deltas.apply(lambda row: row[np.argmin(row.abs())], axis=1).dt.days
    nearest_onset.index = group_df.index

    # latest onset on or before the row (used where ili is 1): future onsets
    # are replaced by a span longer than the whole series so `min` skips them
    longer_than_series = dates.max() - dates.min() + pd.to_timedelta(1, 'd')
    since_onset = onset_deltas.copy()
    since_onset[since_onset < pd.to_timedelta(0, 'd')] = longer_than_series
    since_onset = since_onset.min(axis=1).dt.days
    since_onset.index = group_df.index

    sick_rows = group_df['ili'] == 1
    group_df.loc[sick_rows, 'time_to_onset'] = since_onset[sick_rows.values]

    healthy_rows = group_df['ili'] == 0
    group_df.loc[healthy_rows, 'time_to_onset'] = nearest_onset[healthy_rows.values]

    return group_df


# assert not 'time_to_onset' in dfs['survey'].columns, 'time_to_onset already present in this survey, may not need to recalculate.'

# register `progress_apply` on pandas objects
tqdm().pandas()

# group_keys=False keeps the (participant_id, date) index unchanged; with the
# pandas >= 2.0 default the group key would be prepended as an extra index
# level because calculate_tto returns frames indexed like its input.
dfs['survey'] = dfs['survey'].groupby('participant_id', group_keys=False).progress_apply(calculate_tto)
# overwrite the 'survey' table in the file it was read from
# (`key` by keyword: positional use is deprecated in recent pandas)
dfs['survey'].to_hdf(os.path.join(args.data_dir, 'test', fname), key='survey')
