<a href="https://colab.research.google.com/github/emartinmorgan/hello-world/blob/master/logpreg_sankey.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[Sankey Tutorial](https://towardsdatascience.com/visualizing-in-app-user-journey-using-sankey-diagrams-in-python-8373a7bb2d22)

In [137]:
# Load data
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [138]:
# !pip install chart_studio
import pandas as pd
# import chart_studio as cs
# import chart_studio.plotly as cspy
import plotly.graph_objects as go
import plotly.express as px
import seaborn as se
from google.colab import files
print("hello world")

hello world


In [139]:
all = pd.read_csv("/content/drive/MyDrive/OSU - BMI/Fareed Lab/Thesis/Data_Final/Sankey/sank_all_fin.csv")
all.head()

Unnamed: 0,pat_epi,clus_all_id,clus_all,tri
0,1.1,1.0,Resulters,1
1,3.1,2.0,Result Messagers,1
2,4.2,3.0,Average Users,1
3,5.1,4.0,Average Users,1
4,5.2,5.0,Average Users,1


In [140]:
# Based on the tri, we can compute the rank of each cluster at the pat_epi level:

# a) Sort ascendingly per pat_epi and tri

all.sort_values(['pat_epi', 'tri'],
                 ascending=[True, True], inplace=True)
all.head()

Unnamed: 0,pat_epi,clus_all_id,clus_all,tri
0,1.1,1.0,Resulters,1
10206,1.1,1.0,Result Messagers,2
20412,1.1,1.0,Average Users,3
30618,1.1,1.0,Average Users,4
1,3.1,2.0,Result Messagers,1


In [141]:
# # b) Define a ranking function based on tri, using the method = 'first' param to ensure no clusters have the same rank

# def rank(x): return x['tri'].rank(method='first').astype(int)

# f) Group by user_id 
grouped = all.groupby('pat_epi')

# # c) Apply the ranking function to the all DF into a new "rank_clus" column
# all["rank_clus"] = grouped.apply(rank).reset_index(0, drop=True)

# d) The shift function allows to access the next row's data. Here, we'll want the cluster name
def get_next_clus(x): return x['clus_all'].shift(-1)

# e) Apply the function into a new "next_clus" column
all["next_clus"] = grouped.apply(
     lambda x: get_next_clus(x)).reset_index(0, drop=True)

# f) Regroup by user_id 
grouped = all.groupby('pat_epi')
all.head()

Unnamed: 0,pat_epi,clus_all_id,clus_all,tri,next_clus
0,1.1,1.0,Resulters,1,Result Messagers
10206,1.1,1.0,Result Messagers,2,Average Users
20412,1.1,1.0,Average Users,3,Average Users
30618,1.1,1.0,Average Users,4,
1,3.1,2.0,Result Messagers,1,Resulters


In [142]:
all_clus_at_this_rank

array(['Resulters', 'Result Messagers', 'Average Users', 'Pure Resulters',
       nan], dtype=object)

We attribute a unique color to each event, and name this dict nodes_dict.

In [143]:
# Working on the nodes_dict

all_clus = list(all.clus_all.unique())

# Create a set of colors that you'd like to use in your plot.
palette = ['50BE97', 'E4655C', 'FCC865',
           'BFD6DE', '3E5066', '353A3E', 'E6E6E6']
#  Here, I passed the colors as HEX, but we need to pass it as RGB. This loop will convert from HEX to RGB:
for i, col in enumerate(palette):
    palette[i] = tuple(int(col[i:i+2], 16) for i in (0, 2, 4))

# Append a Seaborn complementary palette to your palette in case you did not provide enough colors to style every event
complementary_palette = se.color_palette(
    "deep", len(all_events) - len(palette))
if len(complementary_palette) > 0:
    palette.extend(complementary_palette)

output = dict()
output.update({'nodes_dict': dict()})

i = 0
for tri in all.tri.unique(): # For each tri of clus...
    # Create a new key equal to the tri...
    output['nodes_dict'].update(
        {tri: dict()}
    )
    
    # Look at all the events that were done at this step of the funnel...
    all_clus_at_this_tri = all[all.tri ==
                                   tri].clus_all.unique()
    
    # Read the colors for these events and store them in a list...
    tri_palette = []
    for clus in all_clus_at_this_tri:
        tri_palette.append(palette[list(all_clus).index(clus)])
    
    # Keep trace of the events' names, colors and indices.
    output['nodes_dict'][tri].update(
        {
            'sources': list(all_clus_at_this_tri),
            'color': tri_palette,
            'sources_index': list(range(i, i+len(all_clus_at_this_tri)))
        }
    )
    # Finally, increment by the length of this rank's available clus to make sure next indices will not be chosen from existing ones
    i += len(output['nodes_dict'][tri]['sources_index'])

In [144]:
all_clus_at_this_tri

array(['Average Users', 'Resulters', 'Result Messagers', 'Schedulers',
       'Non-Users'], dtype=object)

In [145]:
# Working on the links_dict

output.update({'links_dict': dict()})

# Group the DataFrame by pat_epi and tri
grouped = all.groupby(['pat_epi', 'tri'])

# Define a function to read the souces, targets, values clus_all to next_clus:
def update_source_target(user):
    try:
        # user.name[0] is the user's pat_epi; user.name[1] is the tri
        # 1st we retrieve the source and target's indices from nodes_dict
        source_index = output['nodes_dict'][user.name[1]]['sources_index'][output['nodes_dict']
                                                                           [user.name[1]]['sources'].index(user['clus_all'].values[0])]
        target_index = output['nodes_dict'][user.name[1] + 1]['sources_index'][output['nodes_dict']
                                                                               [user.name[1] + 1]['sources'].index(user['next_clus'].values[0])]

         # If this source is already in links_dict...
        if source_index in output['links_dict']:
            # ...and if this target is already associated to this source...
            if target_index in output['links_dict'][source_index]:
                # ...then we increment the count of users with this source/target pair by 1
                output['links_dict'][source_index][target_index]['unique_users'] += 1
            # ...but if the target is not already associated to this source...
            else:
                # ...we create a new key for this target, for this source, and initiate it with 1 user and the time from source to target
                output['links_dict'][source_index].update({target_index:
                                                           dict(
                                                               {'unique_users': 1}
                                                            )
                                                           })
        # ...but if this source isn't already available in the links_dict, we create its key and the key of this source's target, and we initiate it with 1 user and the time from source to target
        else:
            output['links_dict'].update({source_index: dict({target_index: dict(
                {'unique_users': 1})})})
    except Exception as e:
        pass

# Apply the function to your grouped Pandas object:
grouped.apply(lambda user: update_source_target(user)) 



In [167]:
all.head()

Unnamed: 0,pat_epi,clus_all_id,clus_all,tri,next_clus
0,1.1,1.0,Resulters,1,Result Messagers
10206,1.1,1.0,Result Messagers,2,Average Users
20412,1.1,1.0,Average Users,3,Average Users
30618,1.1,1.0,Average Users,4,
1,3.1,2.0,Result Messagers,1,Resulters


In [168]:
all.groupby(['clus_all','tri']).count()

Unnamed: 0_level_0,Unnamed: 1_level_0,pat_epi,clus_all_id,next_clus
clus_all,tri,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Average Billers,2,2665,2665,2665
Average Billers,3,1532,1532,1532
Average Users,1,3342,3342,3342
Average Users,3,2593,2593,2593
Average Users,4,3566,3566,0
Non-Users,1,98,0,98
Non-Users,2,27,0,27
Non-Users,3,59,0,59
Non-Users,4,3,0,0
Pure Resulters,1,462,462,462


In [216]:
targets = []
sources = []
values = []

for source_key, source_value in output['links_dict'].items():
    for target_key, target_value in output['links_dict'][source_key].items():
        sources.append(source_key)
        targets.append(target_key)
        values.append(target_value['unique_users'])
        
labels = []
colors = []

for key, value in output['nodes_dict'].items():
    labels = labels + list(output['nodes_dict'][key]['sources']) 
    colors = colors + list(output['nodes_dict'][key]['color'])
    print(value)
for idx, color in enumerate(colors):
    colors[idx] = "rgb" + str(color) + ""


{'sources': ['Resulters', 'Result Messagers', 'Average Users', 'Schedulers', 'Pure Resulters', 'Non-Users'], 'color': [(80, 190, 151), (228, 101, 92), (252, 200, 101), (191, 214, 222), (230, 230, 230), (53, 58, 62)], 'sources_index': [0, 1, 2, 3, 4, 5]}
{'sources': ['Result Messagers', 'Resulters', 'Average Billers', 'Schedulers', 'Non-Users'], 'color': [(228, 101, 92), (80, 190, 151), (62, 80, 102), (191, 214, 222), (53, 58, 62)], 'sources_index': [6, 7, 8, 9, 10]}
{'sources': ['Average Users', 'Schedulers', 'Resulters', 'Result Messagers', 'Average Billers', 'Non-Users'], 'color': [(252, 200, 101), (191, 214, 222), (80, 190, 151), (228, 101, 92), (62, 80, 102), (53, 58, 62)], 'sources_index': [11, 12, 13, 14, 15, 16]}
{'sources': ['Average Users', 'Resulters', 'Result Messagers', 'Schedulers', 'Non-Users'], 'color': [(252, 200, 101), (80, 190, 151), (228, 101, 92), (191, 214, 222), (53, 58, 62)], 'sources_index': [17, 18, 19, 20, 21]}


['Resulters',
 'Result Messagers',
 'Average Users',
 'Schedulers',
 'Pure Resulters',
 'Non-Users',
 'Result Messagers',
 'Resulters',
 'Average Billers',
 'Schedulers',
 'Non-Users',
 'Average Users',
 'Schedulers',
 'Resulters',
 'Result Messagers',
 'Average Billers',
 'Non-Users',
 'Average Users',
 'Resulters',
 'Result Messagers',
 'Schedulers',
 'Non-Users']

In [212]:
cc=all.groupby(['clus_all','tri']).count()
cc=cc.reset_index()
# cc
cc.sort_values(['tri','clus_all'])
a=np.char.array(cc['clus_all'].values)
b=np.char.array(cc['tri'].values)
c=np.char.array(cc['pat_epi'].values)
labels=(a + b' : ' + c).astype(str)
# cc['clus_all']+cc['tri']+cc['pat_epi']
# np.array(cc[['clus_all','tri']]).to_list()

In [212]:
cc

In [213]:
import numpy as np

In [214]:
# label = ["{} {}".format(node1_name, node1_val), "{} {}".format(node2_name, node2_val) ...]
fig = go.Figure(data=[go.Sankey(
    node=dict(
        thickness=20,  # default is 20
        line=dict(color="black", width=0.5),
        label=labels,
        color=colors
    ),
    link=dict(
        source=sources,
        target=targets,
        value=values,
        hovertemplate='%{value} unique pregnancy episodes went from %{source.label} to %{target.label}.<br />',
    ))])

fig.update_layout(autosize=True, title=dict(text="Movement of Clusters During Pregnancy Episode", font_size=16), font=dict(size=12, family="Arial"), plot_bgcolor='white')
fig.show()
