In [1]:
# Optional: suppress FutureWarning messages so they do not clutter cell output.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# basics
import pandas as pd 
import numpy as np
import os 
import re
from datetime import datetime
from tqdm.notebook import tqdm
tqdm.pandas()
import requests
import urllib

# plotting
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter
import plotly.express as px
import seaborn as sns

# modeling
from patsy import dmatrices
import statsmodels.api as sm
from sklearn.linear_model import LinearRegression
from statsmodels.sandbox.regression.gmm import IV2SLS
from statsmodels.stats.anova import anova_lm

# NOTE(review): statsmodels.api was already imported as `sm` a few lines up;
# the repeat below is redundant but harmless.
import patsy
import sklearn.preprocessing as sklp
import statsmodels.api as sm

# Display option: no limit on the number of columns shown for wide frames.
# NOTE(review): setting mode.chained_assignment to None silences pandas'
# chained-assignment warning for the whole kernel, which can hide genuine
# SettingWithCopy bugs in later cells.
pd.set_option('display.max_columns', None)
pd.options.mode.chained_assignment = None

# NOTE(review): seaborn is bound to both `sns` (above) and `sb` (below), and
# matplotlib.pyplot is imported twice as `plt`; consolidate once it is
# confirmed which alias later cells rely on.
import seaborn as sb
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
# Health-outcome visit-count columns; used as the default column set for filter_nans.
num_visits_col_names = [
    'visits_hematopoietic_cancers',
    'visits_injuries_accidents',
    'visits_type_1_diabetes',
    'visits_pediatric_vasculitis',
    'visits_resp_cardio',
]


In [None]:
# Toy fixture with unequal group sizes: zips 1 and 2 cover four months each,
# zip 3 covers eight months.
dummy_df2 = pd.DataFrame({
    'school_zip': [1] * 4 + [2] * 4 + [3] * 8,
    'year_month': ['2000-01-01', '2000-02-01', '2000-03-01', '2000-04-01'] * 2
                  + ['2000-01-01', '2000-02-01', '2000-03-01', '2000-04-01',
                     '2000-05-01', '2000-06-01', '2000-07-01', '2000-08-01'],
    'visits_hematopoietic_cancers': [None, None, 1, None,
                                     None, None, None, None,
                                     None, None, None, 1, None, None, 1, 1],
    'visits_injuries_accidents': [1, None, 1, None,
                                  None, None, 1, None,
                                  1, None, None, None, 1, 1, 1, 1],
})

display(dummy_df2)

In [11]:
# Minimal toy fixture: three zips, each observed over the same four months.
dummy_df = pd.DataFrame({
    'school_zip': [1] * 4 + [2] * 4 + [3] * 4,
    'year_month': ['2000-01-01', '2000-02-01', '2000-03-01', '2000-04-01'] * 3,
    'visits_hematopoietic_cancers': [None, None, 1, None,
                                     1, None, None, None,
                                     None, None, None, None],
    'visits_injuries_accidents': [1, None, 1, None,
                                  None, None, None, None,
                                  1, None, None, None],
})

display(dummy_df)

Unnamed: 0,school_zip,year_month,visits_hematopoietic_cancers,visits_injuries_accidents
0,1,2000-01-01,,1.0
1,1,2000-02-01,,
2,1,2000-03-01,1.0,1.0
3,1,2000-04-01,,
4,2,2000-01-01,1.0,
5,2,2000-02-01,,
6,2,2000-03-01,,
7,2,2000-04-01,,
8,3,2000-01-01,,1.0
9,3,2000-02-01,,


In [None]:
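# replace_after_1=True:
#   Within each school_zip (separately for each visits column), once a non-null
#   value appears, replace every NaN that comes after it in year_month order
#   with 0. NaNs before the first non-null value are left as NaN.

# replace_after_1=False:
#   If a school_zip has a non-null value in a visits column in any month,
#   replace every NaN in that column for that zip with 0.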


In [22]:
def filter_nans(df, visits_cols=None, replace_after_1=True):
    """Selectively replace NaNs with 0 in visit-count columns, per school zip.

    Rows are first de-duplicated on (school_zip, year_month), keeping the last
    row of each pair. Then, independently for every column in ``visits_cols``
    and every ``school_zip``:

    * ``replace_after_1=True``: NaNs in rows whose ``year_month`` comes on or
      after the earliest ``year_month`` holding a non-null value are replaced
      with 0; NaNs before that month are kept.
    * ``replace_after_1=False``: if the zip has a non-null value in any month,
      every NaN for that zip is replaced with 0.

    Zips with no non-null value in a column are left untouched for that
    column. "Non-null" means any non-missing value (a recorded 0 counts too).
    Chronology comes from sorting ``year_month``, so the input does not need to
    be pre-sorted and its index may be non-contiguous or non-numeric.

    Args:
        df (DataFrame): Input with ``school_zip``, ``year_month`` and the
            ``visits_cols`` columns. It is not modified.
        visits_cols (list, optional): Columns to selectively fill. Defaults to
            the module-level ``num_visits_col_names`` (looked up at call time).
        replace_after_1 (bool, optional): Fill mode described above.
            Defaults to True.

    Returns:
        DataFrame: A new, de-duplicated frame with the original index labels,
        row order and columns, where the ``visits_cols`` columns are filled.
    """
    if visits_cols is None:
        # Resolved at call time so the function can be defined independently
        # of the cell that declares the column list.
        visits_cols = num_visits_col_names

    # One row per (school_zip, year_month): the last one, as before.
    deduped = df.groupby(['school_zip', 'year_month']).tail(1)
    original_index = deduped.index
    # Work on a fresh RangeIndex copy so label alignment below is unambiguous
    # even if df's index is non-contiguous, non-numeric or duplicated.
    result = deduped.reset_index(drop=True)

    # Stable chronological view; cumulative ops walk each zip month by month
    # regardless of the row order in df.
    chrono = result.sort_values('year_month', kind='mergesort')

    for col in visits_cols:
        has_value = result[col].notna()
        if replace_after_1:
            # True from the first month with a value onward, within each zip.
            eligible = (
                chrono[col].notna().astype(int)
                .groupby(chrono['school_zip'])
                .cumsum()
                .gt(0)
                .sort_index()
            )
        else:
            # True for every row of a zip that has a value in some month.
            eligible = result['school_zip'].isin(result.loc[has_value, 'school_zip'])
        result.loc[eligible & ~has_value, col] = 0

    result.index = original_index
    return result

# call function:
# df_all = filter_nans(df_all, visits_cols = num_visits_col_names, replace_after_1 = False)
# display(df_all)

In [13]:
# Show the unfiltered input again for comparison with the filter_nans output further down.
display(dummy_df)

Unnamed: 0,school_zip,year_month,visits_hematopoietic_cancers,visits_injuries_accidents
0,1,2000-01-01,,1.0
1,1,2000-02-01,,
2,1,2000-03-01,1.0,1.0
3,1,2000-04-01,,
4,2,2000-01-01,1.0,
5,2,2000-02-01,,
6,2,2000-03-01,,
7,2,2000-04-01,,
8,3,2000-01-01,,1.0
9,3,2000-02-01,,


In [24]:
# dummy_df2 = filter_nans(dummy_df, visits_cols=['visits_hematopoietic_cancers', 'visits_injuries_accidents'], replace_after_1 = True)
# display(dummy_df2)

In [23]:
# Whole-history mode: any zip with a non-null value in a column gets all of
# its NaNs in that column replaced with 0.
dummy_df2 = filter_nans(
    dummy_df,
    replace_after_1=False,
    visits_cols=[
        'visits_hematopoietic_cancers',
        'visits_injuries_accidents',
    ],
)
display(dummy_df2)

Unnamed: 0,school_zip,year_month,visits_hematopoietic_cancers,visits_injuries_accidents
0,1,2000-01-01,0.0,1.0
1,1,2000-02-01,0.0,0.0
2,1,2000-03-01,1.0,1.0
3,1,2000-04-01,0.0,0.0
4,2,2000-01-01,1.0,
5,2,2000-02-01,0.0,
6,2,2000-03-01,0.0,
7,2,2000-04-01,0.0,
8,3,2000-01-01,,1.0
9,3,2000-02-01,,0.0
