# CMS Diagnosis Categories

### High level process overview

Deleted

### Data Dictionary

**Field Name** | Description  

deleted

### Data Flow

Source Table(s): medical claims 
Reference Table(s): **cms_mcc_xref_2017**  
First GroupBy: "Member_Key","Clami_Key"  
Second GroupBy: "Member_Key"  
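Join Required: Membership dataframe (fields needed: Member_Key2 and active-month totals — pre_month_total or post_month_total depending on direction, plus pre_interval_totals when intervals are used)  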


Note: **Bold** indicates newly created table

##### Expects the following variables to exist:
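all_members, direction, interval_values, interval_flag, hive_context; when interval_flag is True, also history_lookback, interval_type and score_date_type (plus score_date when score_date_type is "fixed")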

##### Expects the following dataframes to exist:
medclms_df, mbrs_df (or atrisk_mbrs_df when all_members is False)

##### Caches the following dataframes:
cms_diag_df

For Intervals:
> cms_diag_df1-6 (based on intervals defined)  
> cms_diag_trend_df (based on intervals defined)  

### Initialization and prerequisite checks

In [None]:
script_start = time.time()


def _unpersist_if_defined(name):
    """Unpersist the DataFrame bound to the global *name*, if it is bound.

    Used to release feature DataFrames cached by a previous run of this
    notebook before they are rebuilt. A name that has never been assigned
    (or is bound to None) is skipped instead of raising.
    """
    cached_df = globals().get(name)
    if cached_df is not None:
        cached_df.unpersist()


_unpersist_if_defined("cms_diag_df")

if interval_flag:
    # Interval and trend frames are only built when intervals are enabled.
    for _cached_name in ["cms_diag_df1", "cms_diag_df2", "cms_diag_df3",
                         "cms_diag_df4", "cms_diag_df5", "cms_diag_df6",
                         "cms_diag_trend_df"]:
        _unpersist_if_defined(_cached_name)

In [None]:
# Pre-requisite check
#
# Verifies that the globals this notebook reads have been set by earlier
# notebooks, and selects the member reference DataFrame (cms_diag_mbrs_df)
# that the feature cells join against. Each failed check raises HaltException
# and stops the run.

try:
    all_members
except NameError:
    raise HaltException("all_members not defined - please set as boolean type (True/False)")

# The membership test must actually be acted on: an unexpected value would
# otherwise slip through and, further down, silently select the
# "post_month_total" (forward) PMPM denominator.
try:
    _direction_ok = direction in ("forward", "backward")
except NameError:
    _direction_ok = False
if not _direction_ok:
    raise HaltException("direction not properly defined - please set to 'forward' or 'backward'")

try:
    medclms_df
except NameError:
    raise HaltException("Medclaims dataframe not created - please run 'Base MedClaims' first")

# all_members selects the full membership list; otherwise the at-risk subset.
if all_members:
    try:
        cms_diag_mbrs_df = mbrs_df
    except NameError:
        raise HaltException("Member reference list not created - please run 'Base Membership' first")
else:
    try:
        cms_diag_mbrs_df = atrisk_mbrs_df
    except NameError:
        raise HaltException("At risk member reference list not created - please create atrisk_mbrs_df first")

try:
    interval_values
    interval_flag
except NameError:
    raise HaltException("interval_flag and interval_values not defined - please set")

print("Success! All prerequisites met for CMS Diagnosis Category Features!")
print("")

### CMS Diagnosis Category Specific Global Variables and Reference Tables

In [None]:
# Uses SparkSQL to load CMS cross reference table from Hive

start = time.time()

#Limited to ICD10 due to differing categories across versions
# Exact query is deleted

temp_cms_xref_df = hive_context.sql(
    """
    ***
    """)


# Creates a dictionary by ICD Version and Diagnosis code to reference specific CMS Category Data.
# Key: icd_diag_vers_cd concatenated with diag_cd; value: that row's remaining
# columns as a dict. Null string columns are filled with 'NA' first.
cms_diag_xref = (temp_cms_xref_df.withColumn("unique_id", f.concat('icd_diag_vers_cd', 'diag_cd'))
                                 .na.fill('NA')
                                 .toPandas().set_index("unique_id").T.to_dict())

# Generates master lists of CMS Level 1 and 2 Category descriptions (sorted, so
# the generated feature columns have a stable order between runs).
cmsd1_desc_list = sorted(set(entry['cms_desc'] for entry in cms_diag_xref.values()
                             if entry['cms_desc'] != 'NA'))
cmsd2_desc_list = sorted(set(entry['cms_desc'] + '_' + entry['sub_cms_desc'] for entry in cms_diag_xref.values()
                             if entry['cms_desc'] != 'NA' and entry['sub_cms_desc'] != 'NA'))

# Every name above is necessarily bound once execution reaches this point (a
# failure would already have raised), so the meaningful failure mode is an
# empty reference table, which would silently produce no CMS Level 1 features.
if not cms_diag_xref or not cmsd1_desc_list:
    raise HaltException("Something went wrong... failed to create at least one of the reference lists or variables")

print("CMS Diagnosis Category Reference Lists successfully created!")
print("")

### CMS Diagnosis Category Specific Supporting Python Functions

In [None]:
#Note: "cms_feature" input is currently limited to "cms_desc" and "sub_cms_desc"

print("Loading CMS Diagnosis Category specific functions...")
print("")

#Note: Modified general functions "flat_encode_ct" and "array_pmpm" to support cms and sub_cms reference lists

# Replaces diagnosis code with corresponding CMS Level 1 or 2 Category
def cms_data (icd_ver, dx1, dx2, dx3, dx4, dx5, dx6, dx7, dx8, dx9, cms_feature):
    """Map one claim line's nine diagnosis codes to their distinct CMS categories.

    Each code is looked up in the global ``cms_diag_xref`` under the key
    ``icd_ver + code``.

    Args:
        icd_ver: ICD version code, prefixed to each diagnosis code for the lookup.
        dx1..dx9: diagnosis codes; None (a SQL NULL passed through the UDF) is allowed.
        cms_feature: "cms_desc" for the Level 1 label, or "sub_cms_desc" for the
            combined "<cms_desc>_<sub_cms_desc>" Level 2 label.

    Returns:
        List of the distinct category labels found. Unknown codes, and missing
        codes or version, contribute 'NA'. Order is unspecified.
    """
    categories = set()
    for dx in (dx1, dx2, dx3, dx4, dx5, dx6, dx7, dx8, dx9):
        # Concatenating None would raise TypeError and fail the Spark task;
        # treat a missing code (or version) like an unknown one.
        if icd_ver is None or dx is None:
            categories.add('NA')
            continue
        entry = cms_diag_xref.get(icd_ver + dx)
        if entry is None:
            categories.add('NA')
        elif cms_feature == 'sub_cms_desc':
            categories.add(str(entry['cms_desc'] + '_' + entry[cms_feature]))
        else:
            categories.add(str(entry[cms_feature]))
    return list(categories)

# Start of code specific to running CMS Diagnosis Categories Macro

In [None]:
# Creates features by CMS Category Types in array equal to length of unique values among CMS Level 1 & 2 categories

print("Generating CMS Diagnosis Category related features...")
print("")

# Spark UDF wrappers around the Python helpers, listed by the level at which
# the feature loop applies them:
#   get_cms        - claim line: wraps cms_data; array<string> of CMS labels
#   get_flat_ct    - logical claim: array<int>, builds the "cmsd*_cts" columns
#   get_agg_arrays - member: array<int>, called with a boolean flag to build the
#                    "_ct_max" (False) and "_ind" (True) columns
#   get_array_pmpm - after the membership join: array<float> PMPM / pct columns
get_cms = udf(cms_data, ArrayType(StringType()))
get_flat_ct = udf(flat_encode_ct, ArrayType(IntegerType()))
get_agg_arrays = udf(agg_arrays, ArrayType(IntegerType()))
get_array_pmpm = udf(array_pmpm, ArrayType(FloatType()))

#Function explodes elements of each row into a giant single tuple
def extract_cmsd(row):
    """Flatten one joined member row into a single flat tuple.

    The layout matches the base feature schema:
    (Member_Key2, cmsd1_ind..., pmpm_cmsd1_ct..., cmsd1_pct..., cmsd2_ind...,
    pmpm_cmsd2_ct...).

    Members without claims arrive with cmsd1_ind = None. They get int zeros
    for the indicator slots and float zeros for the PMPM/pct slots, sized from
    cmsd1_desc_list / cmsd2_desc_list.
    """
    if row.cmsd1_ind is not None:
        # Claims present: concatenate the per-category arrays in schema order.
        return ((row.Member_Key2,)
                + tuple(row.cmsd1_ind)
                + tuple(row.pmpm_cmsd1_ct)
                + tuple(row.cmsd1_pct)
                + tuple(row.cmsd2_ind)
                + tuple(row.pmpm_cmsd2_ct))
    level1_len = len(cmsd1_desc_list)
    level2_len = len(cmsd2_desc_list)
    return ((row.Member_Key2,)
            + (0,) * level1_len            # cmsd1 indicators
            + (0.0,) * (level1_len * 2)    # cmsd1 PMPM counts, then pct
            + (0,) * level2_len            # cmsd2 indicators
            + (0.0,) * level2_len)         # cmsd2 PMPM counts
    
def extract_interval_cmsd(row):
    """Flatten one joined member row of an interval pass into a flat tuple.

    Returns (Member_Key2, <one PMPM float per CMS Level 1 category>). Rows whose
    pmpm_cmsd1_ct is None (e.g. members with no claims in the interval) get
    float zeros sized from cmsd1_desc_list.
    """
    pmpm_values = row.pmpm_cmsd1_ct
    if pmpm_values is None:
        pmpm_values = [0.0] * len(cmsd1_desc_list)
    return (row.Member_Key2,) + tuple(pmpm_values)

# An empty interval list means "no intervals": fall back to a single base pass.
if interval_values == []:
    interval_values = [0]

if interval_flag:
    # Interval step i covers the window (interval_values[i-1], interval_values[i]]
    # counted back from the score/index date, so the values must start at 0,
    # rise strictly, and end exactly at history_lookback.
    if max(interval_values) > history_lookback:
        raise HaltException("Interval values exceed history period - please set last value to length of history period")
    elif max(interval_values) < history_lookback:
        raise HaltException("""Interval values all smaller than history period - please set last value to length of 
        history period""")
    elif interval_values[0] != 0:
        # Check the first element (not just the minimum) so the test matches the message.
        raise HaltException("First interval value must be zero")
    elif len(interval_values) > 7:
        raise HaltException("Too many intervals!  Max length of 7 numbers, starting with zero and ending with history_lookback")
    elif len(interval_values) < 3:
        raise HaltException("Too few intervals!  Min length of 3 numbers, starting with zero and ending with history_lookback")
    elif any(later <= earlier for earlier, later in zip(interval_values, interval_values[1:])):
        # Out-of-order or repeated values would produce empty/inverted windows.
        raise HaltException("Interval values must be strictly increasing, starting with zero and ending with history_lookback")
    else:
        print("Interval features will be calculated during history_lookback period by " + interval_type +
              " according to this pattern " + str(interval_values))
    # Names and column suffixes of the interval DataFrames built by the feature loop.
    cmsd_df_list = []
    cmsd_suffix_list = []
    stop_range = len(interval_values)
else:
    stop_range = 1
    
# Loop-invariant inputs shared by every pass of the feature loop.
# Diagnosis-related columns passed to get_cms (version first, then nine codes).
cms_diag_cols = ["ICD_SRC_VERS_CD", "PRIMARY_DIAG_CD", "DIAG_CD2", "DIAG_CD3", "DIAG_CD4",
                 "DIAG_CD5", "DIAG_CD6", "DIAG_CD7", "DIAG_CD8", "DIAG_CD9"]
# Spark literal arrays of the category labels, in the same order as the generated column names.
cmsd1_lit_array = f.array([f.lit(cms_label) for cms_label in cmsd1_desc_list])
cmsd2_lit_array = f.array([f.lit(cms_label) for cms_label in cmsd2_desc_list])

for step in range(stop_range):
    if step == 0:
        # ---- Step 0: base features over the whole claims DataFrame ----
        # Creates count of claims by CMS Level 1 & 2 at the member level along with an overall claim count (logical claims)
        temp_cms_diag1 = (medclms_df.withColumn("cms_desc", get_cms(*(cms_diag_cols + [f.lit("cms_desc")])))
                                    .withColumn("sub_cms_desc", get_cms(*(cms_diag_cols + [f.lit("sub_cms_desc")])))
                                    .groupBy("Member_Key", "Clami_Key")
                                    .agg(f.collect_list("cms_desc").alias("cms_desc"),
                                         f.collect_list("sub_cms_desc").alias("sub_cms_desc"))
                                    .withColumn("cmsd1_cts", get_flat_ct("cms_desc", cmsd1_lit_array))
                                    .withColumn("cmsd2_cts", get_flat_ct("sub_cms_desc", cmsd2_lit_array))
                                    .groupBy("Member_Key")
                                    .agg(f.collect_list("cmsd1_cts").alias("cmsd1_cts"),
                                         f.collect_list("cmsd2_cts").alias("cmsd2_cts"),
                                         f.count("Clami_Key").alias("clm_cnt"))
                                    .withColumn("cmsd1_ct_max", get_agg_arrays("cmsd1_cts", f.lit(False)))
                                    .withColumn("cmsd2_ct_max", get_agg_arrays("cmsd2_cts", f.lit(False)))
                                    .withColumn("cmsd1_ind", get_agg_arrays("cmsd1_cts", f.lit(True)))
                                    .withColumn("cmsd2_ind", get_agg_arrays("cmsd2_cts", f.lit(True)))
                                    .select('Member_Key', 'cmsd1_ct_max', 'cmsd1_ind', 'cmsd2_ct_max', 'cmsd2_ind',
                                            'clm_cnt'))

        # Adds PMPM count and percentage of claims by CMS Level 1 & 2 Categories;
        # the month-total column used as the denominator depends on direction.
        pmpm_col = "pre_month_total" if direction == "backward" else "post_month_total"

        # Left outer join keeps every member in the reference list; members
        # without claims carry nulls that extract_cmsd turns into zeros.
        # Make sure order of columns in the final select matches cmsd_schema below.
        temp_cms_diag2 = (cms_diag_mbrs_df.join(temp_cms_diag1, temp_cms_diag1.Member_Key == cms_diag_mbrs_df.Member_Key2,
                                                "left_outer")
                                          .select([cms_diag_mbrs_df["Member_Key2"], cms_diag_mbrs_df[pmpm_col]] +
                                                  [c for c in temp_cms_diag1.columns if c not in ['Member_Key']])
                                          .withColumn("pmpm_cmsd1_ct", get_array_pmpm("cmsd1_ct_max", col(pmpm_col),
                                                                                      cmsd1_lit_array))
                                          .withColumn("pmpm_cmsd2_ct", get_array_pmpm("cmsd2_ct_max", col(pmpm_col),
                                                                                      cmsd2_lit_array))
                                          .withColumn("cmsd1_pct", get_array_pmpm("cmsd1_ct_max",
                                                                                  (col("clm_cnt") / col(pmpm_col)),
                                                                                  cmsd1_lit_array))
                                          .select("Member_Key2", "cmsd1_ind", "pmpm_cmsd1_ct", "cmsd1_pct", "cmsd2_ind",
                                                  "pmpm_cmsd2_ct"))

        # Explicit schema (ByteType indicators, FloatType measures) rather than an
        # inferred one, to reduce storage/memory; column names come from the desc lists.
        cmsd_schema = StructType([StructField("Member_Key_cms_diag", LongType(), True)] +
                                 [StructField("cmsd1_" + cms_label + "_ind", ByteType(), True)
                                  for cms_label in cmsd1_desc_list] +
                                 [StructField("cmsd1_" + cms_label + "_pmpm_ct", FloatType(), True)
                                  for cms_label in cmsd1_desc_list] +
                                 [StructField("cmsd1_" + cms_label + "_pct", FloatType(), True)
                                  for cms_label in cmsd1_desc_list] +
                                 [StructField("cmsd2_" + cms_label + "_ind", ByteType(), True)
                                  for cms_label in cmsd2_desc_list] +
                                 [StructField("cmsd2_" + cms_label + "_pmpm_ct", FloatType(), True)
                                  for cms_label in cmsd2_desc_list])

        # Round-trip through an RDD to explode the array columns into one column per category.
        cms_diag_df = (hive_context.createDataFrame(temp_cms_diag2.rdd.map(extract_cmsd), cmsd_schema)
                                   .na.fill(0.0).na.fill(0).cache())

        try:
            # count() is the first action, so it both materializes the cache and surfaces Spark errors.
            print("\t" + "Total Unique Members: " + str(cms_diag_df.count()))
            print("\t" + "Total Features: " + str(len(cms_diag_df.columns) - 1))
        except Exception:
            raise HaltException("Something went wrong... cms_diag_df not successfully created - see PuTTY console for details")
    else:
        # ---- Steps 1..N-1: CMS Level 1 PMPM features for one interval window ----
        print("Generating Features for Interval " + str(step) + "...")

        # Window (anchor - interval_values[step], anchor - interval_values[step-1]] months,
        # applied to both service and process dates. The anchor is the fixed score
        # date or each claim's INDEX_DATE.
        if score_date_type == "fixed":
            anchor_date = f.to_date(f.lit(score_date))
        else:
            anchor_date = f.to_date(medclms_df.INDEX_DATE, "yyyy-MM-dd")
        window_start = f.add_months(anchor_date, -1 * interval_values[step])      # exclusive
        window_end = f.add_months(anchor_date, -1 * interval_values[step - 1])    # inclusive
        serv_date = f.to_date(medclms_df.SERV_FROM_DATE, "yyyy-MM-dd")
        proc_date = f.to_date(medclms_df.PROCESS_DATE, "yyyy-MM-dd")
        cmsdclms_df = medclms_df.filter((serv_date > window_start) & (serv_date <= window_end) &
                                        (proc_date > window_start) & (proc_date <= window_end))

        # Creates count of claims by CMS Level 1 at the member level along with an overall claim count (logical claims)
        temp_cms_diag1 = (cmsdclms_df.withColumn("cms_desc", get_cms(*(cms_diag_cols + [f.lit("cms_desc")])))
                                     .groupBy("Member_Key", "Clami_Key")
                                     .agg(f.collect_list("cms_desc").alias("cms_desc"))
                                     .withColumn("cmsd1_cts", get_flat_ct("cms_desc", cmsd1_lit_array))
                                     .groupBy("Member_Key")
                                     .agg(f.collect_list("cmsd1_cts").alias("cmsd1_cts"),
                                          f.count("Clami_Key").alias("clm_cnt"))
                                     .withColumn("cmsd1_ct_max", get_agg_arrays("cmsd1_cts", f.lit(False)))
                                     .select('Member_Key', 'cmsd1_ct_max', 'clm_cnt'))

        # PMPM denominator for this window: the matching element of the
        # pre_interval_totals array column (indexed from the end).
        pmpm_col = "pre_interval_totals"
        pmpm_val = len(interval_values) - step - 1

        temp_cms_diag2 = (cms_diag_mbrs_df.join(temp_cms_diag1, temp_cms_diag1.Member_Key == cms_diag_mbrs_df.Member_Key2,
                                                "left_outer")
                                          .select([cms_diag_mbrs_df["Member_Key2"], cms_diag_mbrs_df[pmpm_col]] +
                                                  [c for c in temp_cms_diag1.columns if c not in ['Member_Key']])
                                          .withColumn("pmpm_cmsd1_ct", get_array_pmpm("cmsd1_ct_max",
                                                                                      col(pmpm_col)[pmpm_val],
                                                                                      cmsd1_lit_array))
                                          .select("Member_Key2", "pmpm_cmsd1_ct"))

        cmsd_interval_schema = StructType([StructField("Member_Key_cms_diag", LongType(), True)] +
                                          [StructField("cmsd1_" + cms_label + "_pmpm_ct", FloatType(), True)
                                           for cms_label in cmsd1_desc_list])

        int_suffix = "m_b4" if interval_type == "months" else "d_b4"
        interval_suffix = ("_" + str(interval_values[step - 1]) + "to" + str(interval_values[step]) +
                           int_suffix)

        # Not cached: only its schema is read before it is renamed below. The
        # renamed frame is the one cached, so no orphan cached copy is left behind.
        cmsd_features_temp = (hive_context.createDataFrame(temp_cms_diag2.rdd.map(extract_interval_cmsd),
                                                           cmsd_interval_schema)
                                          .na.fill(0.0).na.fill(0))

        cmsd_features_list = [str(c) + interval_suffix
                              for c in cmsd_features_temp.columns if c.endswith("_pmpm_ct")]

        # Publish the interval frame as the notebook global cms_diag_df<step>
        # (keyed by Member_Key_cmsd<step>), which the later cells refer to by name.
        df_name = "cms_diag_df" + str(step)
        col_list = ["Member_Key_cmsd" + str(step)] + cmsd_features_list
        interval_df = cmsd_features_temp.toDF(*col_list).cache()
        interval_df.count()  # Forces DF to be cached
        globals()[df_name] = interval_df
        cmsd_df_list.append(df_name)
        cmsd_suffix_list.append(interval_suffix)

In [None]:
# Report the outcome of the feature-generation cell, then its elapsed time.
if not interval_flag:
    print("Success!  The 'cms_diag_df' feature dataframe was created and cached")
else:
    # One interval DataFrame is expected per consecutive pair of interval values.
    expected_intervals = len(interval_values) - 1
    created_intervals = len(cmsd_df_list)
    if created_intervals == expected_intervals:
        print("Success!  The 'cms_diag_df' feature dataframe and " + str(created_intervals) +
              " interval dataframes were created and cached")
    else:
        print("Mixed Results!  The 'cms_diag_df' feature dataframe was created and cached, but at least some\n"
              "        interval dataframes were not!")
        print("Expected " + str(expected_intervals) + ", but only " + str(created_intervals) + " were created!")
    print("Interval dataframes that were created are named the following:")
    for created_name in cmsd_df_list:
        print(str(created_name))
script_stop = time.time()
print('\t' + str(script_stop - script_start) + 's')

In [None]:
if interval_flag:
    script_start = time.time()
    print("")
    print("Generating Trend Features...")
    print("")

    df_list_length = len(cmsd_df_list)
    # cmsd_suffix_list is in step order (most recent window first). Reversed,
    # suffix_list[y] is the older and suffix_list[y + 1] the newer window of each pair.
    suffix_list = cmsd_suffix_list[::-1]
    feature_list = [c for c in cms_diag_df.columns if c.startswith("cmsd1_") and c.endswith("_pmpm_ct")]

    if df_list_length < 2:
        raise HaltException("Something went wrong... no trend dataframes created")

    # Join interval frames 2..N onto frame 1 by member key, dropping each
    # duplicate key column. The result is built directly rather than probed
    # for via temp_trend_df* names, which could be stale leftovers from an
    # earlier run of this notebook with more intervals.
    final_trend_df = cms_diag_df1
    for df_idx in range(2, df_list_length + 1):
        interval_df = globals()["cms_diag_df" + str(df_idx)]
        key_col = "Member_Key_cmsd" + str(df_idx)
        final_trend_df = (final_trend_df.join(interval_df, interval_df[key_col] == final_trend_df.Member_Key_cmsd1)
                                        .drop(interval_df[key_col]))

    # For each Level 1 PMPM feature and each pair of adjacent windows, label the
    # change from the older to the newer window. Labels use the ratio newer/older
    # once both-zero / zero-to-positive / positive-to-zero are handled. The when()
    # clauses are checked in order, so shared between() endpoints go to the first match.
    for x in feature_list:
        base_name = str(x)
        for y in range(len(interval_values) - 2):
            col_string = (base_name + "_T_" + str(interval_values[-y - 1]) + "-" + str(interval_values[-y - 2]) + "-" +
                          str(interval_values[-y - 3]) + int_suffix)
            older = col(base_name + suffix_list[y])
            newer = col(base_name + suffix_list[y + 1])
            ratio = newer / older
            final_trend_df = final_trend_df.withColumn(col_string,
                                                       f.when((older == 0.0) & (newer == 0.0), "No Activity")
                                                        .when((older == 0.0) & (newer > 0.0), "New")
                                                        .when((older > 0.0) & (newer == 0.0), "Resolved")
                                                        .when(ratio < 0.125, "Dec_over_8x")
                                                        .when(ratio.between(0.125, 0.25), "Dec_4x-8x")
                                                        .when(ratio.between(0.25, 0.5), "Dec_2x-4x")
                                                        .when(ratio.between(0.5, 0.99), "Dec_1x-2x")
                                                        .when(ratio.between(1.01, 2.0), "Inc_1x-2x")
                                                        .when(ratio.between(2.0, 4.0), "Inc_2x-4x")
                                                        .when(ratio.between(4.0, 8.0), "Inc_4x-8x")
                                                        .when(ratio > 8.0, "Inc_over_8x")
                                                        .otherwise("No_Change"))

    cms_diag_trend_df = (final_trend_df.select([col("Member_Key_cmsd1").alias("Member_Key_cmsd_t")] +
                                               [c for c in final_trend_df.columns if '_T_' in c]).cache())

    try:
        # count() is the first action: it materializes the cache and surfaces Spark errors.
        print("\t" + "Total Unique Members: " + str(cms_diag_trend_df.count()))
        print("\t" + "Total Trend Features: " + str(len(cms_diag_trend_df.columns) - 1))
    except Exception:
        raise HaltException("Something went wrong... cms_diag_trend_df not successfully created - see PuTTY console\n"
                            "        for details")
    print("Success!  The 'cms_diag_trend_df' feature dataframe was created and cached")
    script_stop = time.time()
    print('\t' + str(script_stop - script_start) + 's')