# NSDUH Drug Sequence Analysis Part 5 v2:  Constructing Clusters from Stability
## Matthew J. Beattie
## University of Oklahoma
__February 8, 2022__

### Stability Groups
In step 4b, we created a list of stable pairs of observations.  These are observations that were paired together in all of the runs of the KNN sampled clustering technique.  In this script, we reassemble those pairs into clusters and analyze them.

### Approach
* Read in the node-pair dataset.  Going forward, we will refer to this as an edge list.
* Determine the number of connected components in the graph.  Do this for each level of minimum stability in the graph.  Use GraphFrames to find the connected components and save the results to Azure blob storage.

In [0]:
# Import pyspark libraries
from pyspark.sql import functions as f
from pyspark.sql import SparkSession, DataFrameWriter as dfw
from pyspark.sql.types import *
from pyspark.sql.functions import udf
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, ArrayType

# Import libraries for GraphFrames
import graphframes as gf

# Import standard Python libraries
from os.path import abspath
import copy
import os
import sys
import pathlib, itertools
import time
import random
import pickle
import json
import mlflow
import mlflow.sklearn
from collections import Counter
import profile
import gc
import csv
from matplotlib import pyplot as plt


# Build (or attach to) the Hive-enabled Spark session used by every cell below.
# The Delta retention-duration check is switched off through session config.
spark = (
    SparkSession.builder
    .config("spark.databricks.delta.retentionDurationCheck.enabled", "false")
    .enableHiveSupport()
    .getOrCreate()
)

# The connected components algorithm cannot run without a checkpoint directory.
# `sc` is not created in this file; presumably the notebook runtime supplies it.
sc.setCheckpointDir(dirName="/FileStore")

# Set Azure parameters
# Storage account and container that hold this notebook's input and output files.
blob_account_name = "abuseseqstorage"
blob_container_name = "datafiles"
# NOTE(review): this SAS token is a credential stored in plain text in the
# notebook.  Its `se=` field reads 2022-04-05, so it has presumably expired, but
# a replacement should be loaded from a secret store rather than pasted here.
blob_sas_token = 'sv=2020-08-04&st=2022-02-05T21%3A37%3A40Z&se=2022-04-05T20%3A37%3A00Z&sr=c&sp=racwdl&sig=8bnXmCYRpvR93dN7eN1%2B8v%2F7cXD2dXH5z2Fus3vNSVc%3D'
# Register the token under the container/account-specific Spark config key so
# that wasbs:// paths for this container can be read and written.
spark.conf.set('fs.azure.sas.%s.%s.blob.core.windows.net' % (blob_container_name, blob_account_name), blob_sas_token)

# Plot defaults: figure width/height and one font size shared by all text.
FIGW = 12
FIGH = 5
FONTSIZE = 8
FIGURESIZE = (FIGW, FIGH)

# Push the defaults into matplotlib in one update: figure size, base font
# size, and the tick-label sizes on both axes.
plt.rcParams.update({
    'figure.figsize': FIGURESIZE,
    'font.size': FONTSIZE,
    'xtick.labelsize': FONTSIZE,
    'ytick.labelsize': FONTSIZE,
})


### Read in AFU vector and demographic data

In [0]:
# Load the clustering sample table and strip every label column from it:
# 'labels' itself plus the numbered columns 'labels_0' through 'labels_19'.
label_columns = ['labels'] + ['labels_%d' % run for run in range(20)]
df2 = spark.sql("""
    select * from abuse_sequence.clustsamp
""").drop(*label_columns)
print('Count of df2:', df2.count())
display(df2)

RESPID,AFUVECT,YRWEIGHT
201629587143.0,"List(0, 991, 14, 991, 991, 991, 991, 991, 991, 991)",1479.1442
201675987143.0,"List(0, 18, 14, 18, 23, 991, 23, 991, 991, 991)",613.19666
201643328143.0,"List(0, 18, 19, 991, 991, 991, 991, 991, 991, 991)",512.1833
201630438143.0,"List(0, 18, 991, 991, 991, 991, 991, 991, 991, 991)",740.19476
201668869143.0,"List(0, 18, 21, 991, 991, 991, 991, 991, 991, 991)",1397.463
201699371143.0,"List(0, 12, 13, 16, 991, 991, 991, 991, 991, 991)",790.9822
201649893143.0,"List(0, 21, 17, 22, 991, 991, 991, 991, 991, 991)",135.97232
201629223143.0,"List(0, 35, 18, 991, 991, 991, 991, 991, 991, 991)",3499.109
201679504143.0,"List(0, 16, 16, 18, 991, 991, 991, 991, 991, 991)",6434.9214
201625474143.0,"List(0, 18, 991, 21, 22, 991, 991, 991, 991, 991)",1722.7317


In [0]:
# Load the demographic table from blob storage.
# Build the wasbs:// path of the tab-delimited file and make sure the SAS token
# is registered for the container before reading.
demogcsv = 'wasbs://%s@%s.blob.core.windows.net/%s' % (blob_container_name, blob_account_name, 'dfdemog.txt')
spark.conf.set('fs.azure.sas.%s.%s.blob.core.windows.net' % (blob_container_name, blob_account_name), blob_sas_token)
print('Remote blob path: ' + demogcsv)

# Explicit schema for the file: a string respondent id, eleven integer-coded
# demographic fields, a float weight, and an integer label.  Every field is
# declared non-nullable.
integer_fields = [
    "CATAG6", "SVCFLAG", "IRSEX", "IRMARIT", "NEWRACE2", "EDUHIGHCAT",
    "IRWRKSTAT", "GOVTPROG", "INCOME", "COUTYP4", "AIIND02",
]
clust_schema = StructType(
    [StructField("RESPID", StringType(), False)]
    + [StructField(name, IntegerType(), False) for name in integer_fields]
    + [
        StructField("YRWEIGHT", FloatType(), False),
        StructField("labels", IntegerType(), False),
    ]
)

# Read the file with its header row, then drop the weight and label columns so
# that only the respondent id and the demographic fields remain.
dfdemog = (
    spark.read
    .format("csv")
    .schema(clust_schema)
    .options(sep="\t", header="true")
    .load(demogcsv)
)
df3 = dfdemog.drop('YRWEIGHT', 'labels')

display(df3)
print('The number of demographic observations is', df3.count())

RESPID,CATAG6,SVCFLAG,IRSEX,IRMARIT,NEWRACE2,EDUHIGHCAT,IRWRKSTAT,GOVTPROG,INCOME,COUTYP4,AIIND02
201611635143.0,3,0,2,1,1,4,4,2,4,3,2
201635755143.0,4,0,1,1,7,1,3,2,2,1,2
201692675143.0,6,0,2,1,1,3,2,2,3,1,2
201659596143.0,3,0,1,1,5,4,4,2,2,2,2
201641106143.0,5,0,1,2,1,2,4,2,2,2,2
201696416143.0,4,0,1,1,1,4,1,2,4,2,2
201673716143.0,3,0,1,4,2,2,1,1,1,2,2
201676226143.0,2,0,1,4,1,2,1,1,2,3,2
201661056143.0,4,0,1,1,7,1,1,1,3,2,2
201683666143.0,3,0,2,1,1,4,1,2,4,2,2


In [0]:
# Match each AFU-vector row with its demographic row on RESPID, then rename the
# key column to 'id'.  dfall is later passed to GraphFrame as the vertex table,
# which is presumably why the key is called 'id'.
demog_joined = df2.join(df3, on=['RESPID'])
dfall = demog_joined.withColumnRenamed('RESPID', 'id')
display(dfall)
print('The number of AFU vector and demographic observations is', dfall.count())

id,AFUVECT,YRWEIGHT,CATAG6,SVCFLAG,IRSEX,IRMARIT,NEWRACE2,EDUHIGHCAT,IRWRKSTAT,GOVTPROG,INCOME,COUTYP4,AIIND02
201629587143.0,"List(0, 991, 14, 991, 991, 991, 991, 991, 991, 991)",1479.1442,4,0,2,1,7,4,4,2,4,1,2
201675987143.0,"List(0, 18, 14, 18, 23, 991, 23, 991, 991, 991)",613.19666,3,0,2,1,1,4,4,2,4,1,2
201643328143.0,"List(0, 18, 19, 991, 991, 991, 991, 991, 991, 991)",512.1833,2,0,1,4,1,2,3,2,3,2,2
201630438143.0,"List(0, 18, 991, 991, 991, 991, 991, 991, 991, 991)",740.19476,2,0,2,4,1,1,4,2,2,3,2
201668869143.0,"List(0, 18, 21, 991, 991, 991, 991, 991, 991, 991)",1397.463,4,0,2,1,1,3,2,2,3,1,2
201699371143.0,"List(0, 12, 13, 16, 991, 991, 991, 991, 991, 991)",790.9822,4,0,2,3,1,4,3,2,1,2,2
201649893143.0,"List(0, 21, 17, 22, 991, 991, 991, 991, 991, 991)",135.97232,3,0,2,1,1,3,1,2,3,3,2
201629223143.0,"List(0, 35, 18, 991, 991, 991, 991, 991, 991, 991)",3499.109,6,0,2,1,1,2,4,2,2,2,2
201679504143.0,"List(0, 16, 16, 18, 991, 991, 991, 991, 991, 991)",6434.9214,5,0,1,1,7,4,1,2,4,2,2
201625474143.0,"List(0, 18, 991, 21, 22, 991, 991, 991, 991, 991)",1722.7317,4,0,1,4,1,2,1,1,1,3,2


### Read in tuple data from step 4b and reformat for use by GraphFrames

In [0]:
# Read in the tuple set for connected component analysis.  The endpoint columns
# are renamed to 'src' and 'dst' and the table is cached because it is reused
# below and by later cells.
dfedges = spark.sql("""select * from abuse_sequence.sparktuplestability""")\
               .withColumnRenamed('orignode', 'src')\
               .withColumnRenamed('termnode', 'dst').cache()
display(dfedges)

# Generate number of tuples for specific stability values
columns = ['c','stability','tuplecount','cumtuplecount']
tuplestats = []

# B is the denominator that turns a tuple count into a stability fraction
# (presumably the number of clustering runs -- confirm against step 4b).
B = 20

# Count the edges for every tottuples value in ONE aggregation, instead of
# running a separate filter + count job over the edge table per threshold.
# NOTE: assumes tottuples is a numeric column, so its values compare equal to
# the integer loop index below.
counts_by_c = {
    row['tottuples']: row['count']
    for row in dfedges.groupBy('tottuples').count().collect()
}

# Walk from the highest count (B) down to 1, accumulating the running total
cumtuplecount = 0
for i in range(B, 0, -1):
    # Edges that show up exactly i times; values with no edges count as 0
    stab = i/B
    tuplecount = counts_by_c.get(i, 0)
    cumtuplecount += tuplecount
    print('The number of tuples for stability =', stab, 'is', tuplecount, 'cumulative count is', cumtuplecount)
    tuplestats.append((i, stab, tuplecount, cumtuplecount))

# Save list to a dataframe
dftupledesc = spark.createDataFrame(tuplestats, columns)
display(dftupledesc)

src,dst,tottuples,stability
201637100920.0,201737554619.0,20,1.0
201637100920.0,201831192363.0,20,1.0
201637100920.0,201645726537.0,13,0.65
201637100920.0,201925271678.0,11,0.55
201637100920.0,201723826514.0,13,0.65
201637100920.0,201639303540.0,11,0.55
201669499330.0,201699372019.0,20,1.0
201669499330.0,201753470655.0,20,1.0
201669499330.0,201973936715.0,11,0.55
201669499330.0,201697435089.0,13,0.65


c,stability,tuplecount,cumtuplecount
20,1.0,105576279,105576279
19,0.95,8490053,114066332
18,0.9,2035977,116102309
17,0.85,275274,116377583
16,0.8,48105,116425688
15,0.75,540777,116966465
14,0.7,876904,117843369
13,0.65,2412893,120256262
12,0.6,130828,120387090
11,0.55,371739,120758829


In [0]:
# Export the tuple statistics to Azure blob storage as one tab-delimited text
# file.  Spark saves into a folder of 'part-' files, so the table is written to
# a temporary folder first and its part file is then moved to the final name.
output_file_name = 'tuplestatsparquet.csv'
final_file_name = 'tuplestats.txt'
output_container_path = "wasbs://%s@%s.blob.core.windows.net" % (blob_container_name, blob_account_name)
output_blob_folder = output_container_path + "/"
output_filename = output_blob_folder + output_file_name
final_filename = output_blob_folder + final_file_name

# coalesce(1) collapses the table to one partition before writing.
tab_writer = (
    dftupledesc.coalesce(1)
    .write
    .format("csv")
    .mode("overwrite")
    .option("header", "true")
    .option("delimiter", '\t')
)
tab_writer.save(output_filename)

# Locate the part file inside the temporary folder (despite the folder's name,
# the content is CSV, not parquet).
files = dbutils.fs.ls(output_filename)
output_file = [entry for entry in files if entry.name.startswith("part-")]

# Move the part file to the container root under its final name, then delete
# the temporary folder and whatever else Spark left in it.
dbutils.fs.mv(output_file[0].path, final_filename)
dbutils.fs.rm(output_filename, recurse=True)

### Use GraphFrames to find connected components
For each stability value, we extract a subgraph from the overall graph G.  Each subgraph keeps only the edges whose stability is at least the value required for that iteration.  This forms B subgraphs.  For each subgraph, we use the graphframes.connectedComponents() routine to find all of the connected components.  Each connected component is analogous to a cluster of observations.  We find the stability value at which the number of connected components begins to explode -- this is similar to how we determine K in a KMC cluster analysis.  We then take the subgraph whose edges have stability of at least that value and find its connected components.  This step is necessary because until this point we have not saved the connected component results.  Finally, we write the connected component list to Azure ADLS2 for further processing on a lower cost compute resource.

In [0]:
# Create graph from observations and edges
# dfall supplies the vertices (its key column was renamed to 'id') and dfedges
# supplies the edges (columns 'src' and 'dst').  The graph is cached because it
# is filtered repeatedly in the cells below.
respidGraph = gf.GraphFrame(dfall, dfedges).cache()


In [0]:
# Show the edge list (src, dst, tottuples, stability) for reference
display(dfedges)

src,dst,tottuples,stability
201637100920.0,201737554619.0,20,1.0
201637100920.0,201831192363.0,20,1.0
201637100920.0,201645726537.0,13,0.65
201637100920.0,201925271678.0,11,0.55
201637100920.0,201723826514.0,13,0.65
201637100920.0,201639303540.0,11,0.55
201669499330.0,201699372019.0,20,1.0
201669499330.0,201753470655.0,20,1.0
201669499330.0,201973936715.0,11,0.55
201669499330.0,201697435089.0,13,0.65


In [0]:
# Generate table of number of components and size of largest by stability.
# One row per threshold i: (i, number of components, id of the largest
# component, size of the largest component).
columns = ['tuplecount', 'numcomps', 'bigcomp', 'bigcompcnt']
vals = []

# Thresholds 1..20 cover every possible tottuples value (B = 20 above)
for i in range(1,21):
    # Restrict edges to those whose tottuples is at least i
    filterstr = 'tottuples>=%s' % i
    respidGraphSubset = respidGraph.filterEdges(filterstr)

    # Generate connected components
    result = respidGraphSubset.connectedComponents()
    result.createOrReplaceTempView('resulttbl')

    # Aggregate component sizes ONCE and cache them.  The previous version ran
    # the same group-by three times and collected every component row to the
    # driver only to read the first one.
    compsizes = result.groupBy('component').count().cache()
    numcomps = compsizes.count()

    # Largest component.  Ties on size are broken by component id, so the id
    # and the size always come from the same row and are reproducible.
    biggest = compsizes.orderBy(f.col('count').desc(), f.col('component')).first()
    dfbigcomp = biggest['component']
    dfbigcompcount = biggest['count']
    compsizes.unpersist()

    print(i, numcomps, dfbigcomp, dfbigcompcount)
    vals.append((i, numcomps, dfbigcomp, dfbigcompcount))

# Save list to a dataframe
dfcompdesc = spark.createDataFrame(vals, columns)
display(dfcompdesc)

tuplecount,numcomps,bigcomp,bigcompcnt
1,1,0,42887
2,1,0,42887
3,1,0,42887
4,1,0,42887
5,1,0,42887
6,1,0,42887
7,1,0,42887
8,4,2,18906
9,4,2,18906
10,7,2,10216


In [0]:
# Export the per-threshold component statistics to Azure blob storage as one
# tab-delimited text file: write to a temporary folder, move the single
# 'part-' file to its final name, then delete the folder.
output_file_name = 'conncompstatsparquet.csv'
final_file_name = 'conncompstats.txt'
output_container_path = "wasbs://%s@%s.blob.core.windows.net" % (blob_container_name, blob_account_name)
output_blob_folder = output_container_path + "/"
output_filename = output_blob_folder + output_file_name
final_filename = output_blob_folder + final_file_name

# One partition in, so Spark writes the table with a header row as CSV output
# using a tab as the delimiter.
stats_writer = (
    dfcompdesc.coalesce(1)
    .write
    .format("csv")
    .mode("overwrite")
    .option("header", "true")
    .option("delimiter", '\t')
)
stats_writer.save(output_filename)

# Pick out the part file from the temporary folder (CSV content, despite the
# "parquet" in the folder name).
files = dbutils.fs.ls(output_filename)
output_file = [entry for entry in files if entry.name.startswith("part-")]

# Rename the part file into the container root, then clean up the folder.
dbutils.fs.mv(output_file[0].path, final_filename)
dbutils.fs.rm(output_filename, recurse=True)

In [0]:
# Generate connected components for selected stability levels, 0.85 (c=17), AND 0.60 (c=12)
# Only one level is computed per run of this cell; c below selects which one.

# Restrict edges to those whose tottuples is at least c
c = 12
filterstr = 'tottuples>=%s' % c
respidGraphSubset = respidGraph.filterEdges(filterstr)

# Generate connected components and expose them to SQL as 'resulttbl'
result = respidGraphSubset.connectedComponents()
result.createOrReplaceTempView('resulttbl')
    


In [0]:
# Replace the AFUVECT array column with a bracketed, comma-separated string
# (e.g. "[0,18,991]").  Every other column keeps its position and the rebuilt
# AFUVECT goes last.
afuvect_text = f.concat(f.lit("["), f.concat_ws(",", f.col("AFUVECT")), f.lit("]"))
other_columns = [name for name in result.columns if name != 'AFUVECT']
saveresult = result.select(*other_columns, afuvect_text.alias('AFUVECT'))
display(saveresult)

id,YRWEIGHT,CATAG6,SVCFLAG,IRSEX,IRMARIT,NEWRACE2,EDUHIGHCAT,IRWRKSTAT,GOVTPROG,INCOME,COUTYP4,AIIND02,component,AFUVECT
201667830023.0,3497.2385,4,0,1,1,2,3,1,2,2,1,2,9,"[0,991,991,991,991,991,991,991,991,991]"
201669805339.0,217.14131,4,0,1,1,1,1,3,1,1,3,2,4,"[0,15,18,12,991,23,991,991,991,991]"
201735334628.0,3999.8267,5,0,2,1,1,3,2,2,4,1,2,62,"[0,23,16,21,22,991,991,991,991,991]"
201967507703.0,956.413,3,0,1,4,1,4,1,2,3,1,2,4,"[0,16,15,16,991,20,991,991,991,991]"
201653283427.0,1601.6357,4,0,1,1,1,3,1,2,2,1,2,0,"[0,13,24,14,991,991,991,991,991,991]"
201723293808.0,836.19977,5,0,2,1,2,3,1,2,4,2,2,12,"[0,50,30,991,991,991,991,991,991,991]"
201615352127.0,1385.9036,4,0,2,1,1,4,4,2,4,1,2,0,"[0,16,13,18,991,991,991,991,991,991]"
201810559555.0,1187.0437,4,0,2,1,5,4,1,2,4,1,2,9,"[0,991,991,991,991,991,991,991,991,991]"
201813423565.0,674.53644,2,0,2,4,1,1,1,2,1,2,2,14,"[0,11,15,13,18,36,18,18,18,19]"
201891366169.0,5975.273,4,0,2,3,7,4,1,2,4,2,2,0,"[0,16,12,12,991,991,991,991,991,991]"


In [0]:
# Expose saveresult to SQL cells under the name 'saveresulttbl'
saveresult.createOrReplaceTempView('saveresulttbl')

In [0]:
%sql
-- Number of observations in each connected component of saveresulttbl
select component, count(*) as compcnt from saveresulttbl
group by component

component,compcnt
0,8993
9,4757
17,1404
5,7435
12,7553
62,1585
2,5119
4,2528
14,1014
21,2498


In [0]:
# Export the per-observation connected component assignments (saveresult) to
# Azure blob storage as one tab-delimited file: write to a temporary folder,
# move the single 'part-' file to its final name, then delete the folder.
# NOTE(review): the file names say "11", but the cell that builds `result` sets
# c = 12 -- confirm which threshold this file is meant to hold.
# NOTE(review): the final file has a .csv extension but is tab-delimited.
output_file_name = 'conncomp11parquet.csv'
final_file_name = 'conncomp11.csv'
output_container_path = "wasbs://%s@%s.blob.core.windows.net" % (blob_container_name, blob_account_name)
output_blob_folder = output_container_path + "/"
output_filename = output_blob_folder + output_file_name
final_filename = output_blob_folder + final_file_name

# Collapse to one partition and write with a header row.
component_writer = (
    saveresult.coalesce(1)
    .write
    .format("csv")
    .mode("overwrite")
    .option("header", "true")
    .option("delimiter", '\t')
)
component_writer.save(output_filename)

# Find the part file Spark produced inside the temporary folder.
files = dbutils.fs.ls(output_filename)
output_file = [entry for entry in files if entry.name.startswith("part-")]

# Move it to the container root under the final name, then remove the folder.
dbutils.fs.mv(output_file[0].path, final_filename)
dbutils.fs.rm(output_filename, recurse=True)