### Import required libraries

In [1]:
import pandas as pd
import numpy as np
import random
from collections import Counter
import plotly.graph_objects as go



### Read and clean up missing data

In [2]:
# load the spreadsheet, discard every row that has a missing value and
# store all columns as 32-bit integers
df = (
    pd.read_excel('clustertransition.xls')
    .dropna()
    .astype('int32')
)
# preview the first 10 rows
df.head(10)

Unnamed: 0,PID,cluster1,cluster2,cluster3,cluster4
1,1002,4,7,4,3
2,1003,4,5,4,4
5,1006,4,5,4,3
7,1010,2,6,4,4
14,1023,4,1,1,7
15,1026,1,2,6,6
17,1028,2,2,6,2
19,1031,1,1,2,1
21,1033,6,6,3,7
22,1034,1,6,1,1





### Convert dataframe to the right format for sankey diagram

* You will need lists of labels, sources, targets, and values. Colors are optional.

In [3]:
# Build one Counter per transition depth. Each key is the concatenation of the
# cluster ids seen so far in a row, e.g. depth 1 gives
# {'4': 19, '2': 9, '1': 6, '6': 6, '3': 5} and depth 2 gives keys like '45'.
# NOTE: ids are concatenated without a separator and the parent of a node is
# found with key[:-1], so this only works while every cluster id is a single
# character.
cluster_cols = df.columns[1:5]  # cluster1 to cluster4
node_counters = []
node_paths = [''] * len(df)
for col in cluster_cols:
    # convert the whole column at once instead of row by row via iterrows()
    steps = df[col].astype(str).tolist()
    node_paths = [path + step for path, step in zip(node_paths, steps)]
    node_counters.append(Counter(node_paths))

# every distinct path is one node of the diagram
labels = sorted(key for counter in node_counters for key in counter)

# Map each label to its first position once, so that building the links does
# not rescan `labels` with list.index() for every source and target.
label_index = {}
for position, label in enumerate(labels):
    label_index.setdefault(label, position)

sources = []
targets = []
values = []
for counter in node_counters[1:]:
    for key, value in counter.items():
        # a key '454' with value 5 means 5 patients moved from group '45'
        # to group '454'
        sources.append(label_index[key[:-1]])  # '45' extracted from '454'
        targets.append(label_index[key])
        values.append(value)

palette = ['deepskyblue', 'darkorange', 'plum', 'lime', 'cyan', 'sandybrown', 'teal']
# color each node by its first state; the modulo cycles through the palette
# instead of raising IndexError when a state number exceeds the palette size
colors = [palette[(int(label[0]) - 1) % len(palette)] for label in labels]



### Display diagram

In [None]:
# appearance and captions of the nodes
node_style = {
    'pad': 2,
    'thickness': 20,
    'line': {'color': "black", 'width': 0.5},
    'label': labels,
    'color': colors,
}
# links between nodes: source[i] -> target[i] with width value[i]
link_data = {
    'source': sources,
    'target': targets,
    'value': values,
}

sankey = go.Sankey(node=node_style, link=link_data)
fig = go.Figure(data=[sankey])

fig.update_layout(
    title_text="Cluster Transition",
    font_size=10,
    width=1000,
    height=800,
)
fig.show()

![](chart.png)