### Import required libraries

In [3]:
import pandas as pd
import numpy as np
import random
from collections import Counter
import plotly.graph_objects as go



### Read data and inspect missing values

In [4]:
# Load the cluster assignments from the Excel workbook into a DataFrame
df = pd.read_excel(io='clustertransition.xls')
# Preview the first ten rows
df.head(n=10)

Unnamed: 0,PID,cluster1,cluster2,cluster3,cluster4
0,1001,,,5.0,4.0
1,1002,4.0,7.0,4.0,3.0
2,1003,4.0,5.0,4.0,4.0
3,1004,7.0,5.0,6.0,
4,1005,,1.0,4.0,4.0
5,1006,4.0,5.0,4.0,3.0
6,1007,6.0,,1.0,5.0
7,1010,2.0,6.0,4.0,4.0
8,1011,3.0,3.0,,
9,1012,4.0,3.0,,4.0





### Convert dataframe to the right format for sankey diagram

* You will need lists of labels, sources, targets, and values. Colors are optional.

In [6]:
# Build the parallel lists a plotly Sankey diagram needs (labels, sources,
# targets, values) plus one color per node.
#
# Node layout: every stage (one cluster column) owns a contiguous group of
# N_CLUSTERS nodes, so cluster k at stage i is node index (k - 1) + i * N_CLUSTERS.
N_CLUSTERS = 7  # cluster ids are expected to be integers in 1..N_CLUSTERS

# Select the stage columns by name rather than by position, so an extra
# leading column (e.g. an exported index) cannot shift the selection.
# Assumes every column whose name starts with 'cluster' is a stage and that
# the columns appear in chronological order.
stage_cols = [col for col in df.columns if str(col).startswith('cluster')]

# labels = ['1', '2' ... '7', '1', '2' ... '7', ...] (one group per stage)
labels = [str(k) for _ in stage_cols for k in range(1, N_CLUSTERS + 1)]

sources = []
targets = []
values = []

for i, (src_col, tar_col) in enumerate(zip(stage_cols[:-1], stage_cols[1:])):
    # Ignore transitions where either end is missing, then truncate to int.
    pairs = df[[src_col, tar_col]].dropna().astype(int)

    # An out-of-range id would silently map onto a node of another stage.
    if ((pairs < 1) | (pairs > N_CLUSTERS)).values.any():
        raise ValueError(
            "cluster ids in columns {!r} and {!r} must be within 1..{}".format(
                src_col, tar_col, N_CLUSTERS))

    # Count each unique (source, target) transition for this pair of stages.
    # Tuples keep multi-digit ids unambiguous, unlike concatenated digit strings.
    counter = Counter(zip(pairs[src_col].tolist(), pairs[tar_col].tolist()))
    for (src_node, tar_node), count in counter.items():
        sources.append(src_node - 1 + i * N_CLUSTERS)
        targets.append(tar_node - 1 + (i + 1) * N_CLUSTERS)
        values.append(count)

palette = ['deepskyblue', 'darkorange', 'plum', 'lime', 'cyan', 'sandybrown', 'teal']
# One color per cluster id, repeated for every stage; the modulo only matters
# if N_CLUSTERS is raised above the palette size.
colors = [palette[(int(label) - 1) % len(palette)] for label in labels]



### Display diagram

In [None]:
# Assemble the Sankey trace: node appearance and labels, plus the link arrays
# (source index, target index, flow size) prepared in the previous cell.
fig = go.Figure(
    data=[
        go.Sankey(
            node={
                "pad": 2,
                "thickness": 20,
                "line": {"color": "black", "width": 0.5},
                "label": labels,
                "color": colors,
            },
            link={
                "source": sources,
                "target": targets,
                "value": values,
            },
        )
    ]
)

# Set the title, font size and canvas dimensions, then render the figure.
fig.update_layout(
    title_text="Cluster Transition",
    font_size=10,
    width=1000,
    height=800,
)
fig.show()

![](chart.png)