In [None]:
# For additional information please refer to:
    # https://plotly.com/python/sankey-diagram/

## Import packages

In [1]:
import pandas as pd
import csv
import seaborn as sns
import numpy as np
import glob
import os
import matplotlib.pyplot as plt
from matplotlib import rcParams
from io import StringIO
import random
import time
import plotly.graph_objects as go
import copy
import sys
from print_versions import print_versions

In [2]:
print("Python version")
print(sys.version)
print("\nPackages version")
print_versions(globals())

Python version
3.10.11 (main, Dec 23 2024, 23:03:26) [Clang 16.0.0 (clang-1600.0.26.6)]

Packages version
pandas==2.2.3
csv==1.0
seaborn==0.13.2
numpy==1.26.4
matplotlib==3.10.0


## Set up the environment

In [None]:
# Note - In order to make the Sankey diagrams the previous csv file needs to have a specific format:
    # 1. The first columns represent the timepoints of the analysis (this will vary depending on the experimental design). Each cell has the label of the phenotype
    # 2. The following column constains the genotype(s)
    # 3. the last column has the number of samples that present a particular genotype and phenotype development
# I recommend to manually create this CSV file first and then run the following code

In [3]:
input_dir = '/Users/xxxx/Library/CloudStorage/xxxx/' # Working directory where the dataset is saved
file_name = input_dir + 'triples_genotype.csv' # Name of the file containing the dataset
df_sankey = pd.read_csv(file_name, sep=';') # Read the .csv file
df_sankey.head() # Print the first lines of the table

Unnamed: 0,2dpf,3dpf,4dpf,5dpf,"genotype (2a, 2b, crlb)",size
0,2dpf_D,3dpf_D&C1,4dpf_C2,5dpf_C2,"hom,het,hom",1
1,2dpf_D,3dpf_D&C1,4dpf_C2,5dpf_C3,"hom,wt,hom",1
2,2dpf_D,3dpf_D&C1,4dpf_C3,5dpf_C3,"hom,hom,het",11
3,2dpf_D,3dpf_D&C1,4dpf_C2,5dpf_C3,"hom,hom,het",1
4,2dpf_D,3dpf_D&C1,4dpf_C3,5dpf_C3,"hom,hom,hom",5


## Define basic plotting parameters

In [4]:
labels = []
for col in df_sankey.columns[:-2]: # Read the dataframe except the columns with the genotype and the sample counts
    label = df_sankey[col].unique() # For each column get the unique labels
    for i in label:
        labels.append(i) # Append the labels into a new list
labels.sort()

dict_label = {}
for i in enumerate(labels):
    dict_label[i[0]] = i[1] # Transform the list into a dictionary

col_list = []
for i in labels:
    color = "darkgrey" # This parameter will determine the color of each label. By default it is set to be grey
    col_list.append(color)
    # Note: To use different colours for each label.
        # 1) Get the total number of labels, len(labels)
        # 2) Provide a colour list with the same length as in (1), col_list = ['red', 'blue', ...]

## Pre-process the working dataframe

In [5]:
for col in df_sankey.columns[:-2]:  # Read the dataframe except the columns with the genotype and the embryo counts
    df_sankey[col] = df_sankey[col].map({v: k for k, v in dict_label.items()}) # Change each label for the index of the dictionary

df_sankey.head() # Ensure that the dataframe only contains numbers and the genotype as strings

Unnamed: 0,2dpf,3dpf,4dpf,5dpf,"genotype (2a, 2b, crlb)",size
0,0,3,6,10,"hom,het,hom",1
1,0,3,6,11,"hom,wt,hom",1
2,0,3,7,11,"hom,hom,het",11
3,0,3,6,11,"hom,hom,het",1
4,0,3,7,11,"hom,hom,hom",5


In [6]:
genolist = df_sankey['genotype (2a, 2b, crlb)'].unique().tolist() # Get the labels from the genotype column
dpf_number = df_sankey.shape[1]-2 # Get the number of columns containing phenotype information (ignore the genotype and counts columns)

In [7]:
df_storage = []
for col in df_sankey.columns: # For each column of the dataframe
    for i in range(0,df_sankey.shape[0]):
        value = df_sankey._get_value(i,col) # Get the values of the column
        df_storage.append(value) # Store all the values of the dataframe as a unique list

In [8]:
chunks = [df_storage[i:i+df_sankey.shape[0]] for i in range(0, len(df_storage), df_sankey.shape[0])] # Divide the previous list based on the length (rows) of the original dataframe

In [9]:
df_storage2 = []
for i in range(0,len(chunks)):
    if i < len(chunks)-3: # 3 is required to not overwrite the columns with the genotype and the total number count.
                          # In this example there are two columns with additional information. If more columns the number above will be n+1.
        try:
            mydf = pd.DataFrame(list(zip(chunks[i], chunks[i+1],chunks[len(chunks)-1],chunks[len(chunks)-2])))
            mydf = mydf.T
            df_storage2.append(mydf)
        except Exception as e:
            pass
    else:
        pass
df_storage2[:1]

[            0           1            2            3            4           5   \
 0            0           0            0            0            0           0   
 1            3           3            3            3            3           3   
 2            1           1           11            1            5           2   
 3  hom,het,hom  hom,wt,hom  hom,hom,het  hom,hom,het  hom,hom,hom  hom,hom,wt   
 
             6            7            8            9   ...           37  \
 0            1            1            1            1  ...            1   
 1            2            4            4            4  ...            4   
 2            1            1            3            6  ...            8   
 3  hom,het,hom  het,het,hom  het,hom,hom  hom,het,het  ...  het,het,het   
 
            38           39          40          41          42         43  \
 0           1            1           1           1           1          1   
 1           4            4           4           

In [10]:
df_storage3 = pd.concat(df_storage2, axis = 1) # Concatenate the previous dataframes into a single table
colors = plt.get_cmap('tab20').colors * 4 # Takes the colours from a Matplotlib palette
        # More colours can be found: https://matplotlib.org/stable/users/explain/colors/colormaps.html
colors_rgb = {name: (int(r*255), int(g*255), int(b*255)) for name, (r, g, b) in zip(genolist, colors)} # Transform the colours from a Matplotlib into RGB
out = {k: colors_rgb[k] for k in list(colors_rgb)[:5]} # Ensure the dictionary contains a colour for each genotype
out

{'hom,het,hom': (31, 119, 180),
 'hom,wt,hom': (174, 199, 232),
 'hom,hom,het': (255, 127, 14),
 'hom,hom,hom': (255, 187, 120),
 'hom,hom,wt': (44, 160, 44)}

In [None]:
# Colours can also be defined manually
colors_rgb = {'rest': (159, 159, 163), #darkgrey
              'x,het,hom': (144, 187, 228),
              'x,hom,hom': (91, 108, 179),
              'x,wt,hom': (120, 72, 156),
              'hom,het,x': (103, 14, 14),
              'hom,het,hom': (216, 70, 39),
              'hom,hom,x': (0, 167, 157),
              'hom,wt,x': (255, 152, 150),
              'hom,wt,hom': (148, 103, 189),
              }

In [11]:
color_row = [colors_rgb[genotype] for genotype in df_storage3.iloc[3]] # Creates a dataframe with the colours for each genotype
df_storage3.loc[4] = color_row # Store the raw values into the fourth row of the dataframe (this number may need to be changed depending on the shape of the dataframe)
def rgb_to_rgba_string(rgb_tuple):
    return f"rgba({rgb_tuple[0]}, {rgb_tuple[1]}, {rgb_tuple[2]}, 1.0)" # Format the colours for the Sankey plot.
    # The value of 0.8 refers to the opacity of the plot and can be changed for any value.
df_storage3.loc[5] = df_storage3.loc[4].apply(rgb_to_rgba_string) # Store the colours in a rgba format

In [12]:
sankey_df = df_storage3 # Create a copy of the working dataframe ensuring subsequent changes do not change raw results
sankey_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,37,38,39,40,41,42,43,44,45,46
0,0,0,0,0,0,0,1,1,1,1,...,8,8,8,8,8,8,8,8,8,8
1,3,3,3,3,3,3,2,4,4,4,...,12,12,12,12,12,12,12,12,12,12
2,1,1,11,1,5,2,1,1,3,6,...,8,4,5,1,4,3,1,1,4,1
3,"hom,het,hom","hom,wt,hom","hom,hom,het","hom,hom,het","hom,hom,hom","hom,hom,wt","hom,het,hom","het,het,hom","het,hom,hom","hom,het,het",...,"het,het,het","het,het,wt","het,hom,het","het,hom,wt","het,wt,het","wt,het,het","wt,het,wt","wt,hom,het","wt,wt,het","wt,wt,wt"
4,"(31, 119, 180)","(174, 199, 232)","(255, 127, 14)","(255, 127, 14)","(255, 187, 120)","(44, 160, 44)","(31, 119, 180)","(152, 223, 138)","(214, 39, 40)","(255, 152, 150)",...,"(188, 189, 34)","(199, 199, 199)","(219, 219, 141)","(23, 190, 207)","(158, 218, 229)","(31, 119, 180)","(174, 199, 232)","(255, 127, 14)","(255, 187, 120)","(44, 160, 44)"


In [13]:
# This cell checks when there is missing data (e.g. empty cells in the original CSV file) and converts the links between the nodes to 0
index_remove = []
for i in dict_label:
    if 'empty' in dict_label[i]: # Change for the label you want to remove (e.g. 'empty', 'classI', etc.)
        indexDict = list(dict_label.keys()).index(i) # Recognises the index of 'empty' in the original dictionary
        index_remove.append(indexDict) # Stores the index in a variable (done in case multiple classes need to be modified)
    else:
        pass
    
for idremove in index_remove:
    mask_index = sankey_df.loc[0,:]==idremove # Creates a mask in the dataframe to specified the columns that need to be modified
    mask_index2 = sankey_df.loc[1,:]==idremove
    sankey_df.loc[2,mask_index] = 0 # Modifies the value of the third row to a 0. This ensures that no links will be plot in the Sankey diagram between the two nodes specified in rows 1 and 2
    sankey_df.loc[2,mask_index2] = 0

In [14]:
max_col = max((i for i in sankey_df.columns.unique() if isinstance(i, int))) # Checks the maximum column index in the dataframe
print(f"The highest column number in the dataframe is {max_col}")

The highest column number in the dataframe is 46


In [15]:
# Creates an additional column to add into the dataframe to expand the white canvas in the final figure to move the nodes
# If bigger space is needed change 100 to a bigger number
new_row_dict = {}
ll = [max_col+1, max_col+2, 100,'random', (0, 0, 0)]
    # Creates the artificial column including two need nodes, a value of 100 for the link, a 'random' genotype and the link will be plot in white.
for key in range(0,sankey_df.shape[0]-1):
    new_row_dict[str(key)] = ll[key] # Add the new column to the previous dataframe
new_row_dict

{'0': 47, '1': 48, '2': 100, '3': 'random', '4': (0, 0, 0)}

In [16]:
def tuple_to_rgba_string(rgb_tuple, opacity):
    return f"rgba({rgb_tuple[0]}, {rgb_tuple[1]}, {rgb_tuple[2]}, {opacity})" # Transform the RGB colours into strings
rgb_tuple = new_row_dict['4']
opacity = 0.0 # Define the opacity of the RGB colours
new_row_dict['5'] = tuple_to_rgba_string(rgb_tuple, opacity)
print(new_row_dict)

{'0': 47, '1': 48, '2': 100, '3': 'random', '4': (0, 0, 0), '5': 'rgba(0, 0, 0, 0.0)'}


In [None]:
# Important note: The following cells cannot be run multiple times. Changes need to be run once and then re-load the original sankey_df file (sankey_df = df_storage3)

In [17]:
new_column_df = pd.DataFrame.from_dict(new_row_dict, orient='index', columns=['New_Column']) # Converts the previous colour dictionary into a dataframe
new_column_df.index = new_column_df.index.astype(int)
sankey_df = sankey_df.join(new_column_df) # Joins the new column into the main dataset that will be plotted
sankey_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,38,39,40,41,42,43,44,45,46,New_Column
0,0,0,0,0,0,0,1,1,1,1,...,8,8,8,8,8,8,8,8,8,47
1,3,3,3,3,3,3,2,4,4,4,...,12,12,12,12,12,12,12,12,12,48
2,1,1,11,1,5,2,1,1,3,6,...,4,5,1,4,3,1,1,4,1,100
3,"hom,het,hom","hom,wt,hom","hom,hom,het","hom,hom,het","hom,hom,hom","hom,hom,wt","hom,het,hom","het,het,hom","het,hom,hom","hom,het,het",...,"het,het,wt","het,hom,het","het,hom,wt","het,wt,het","wt,het,het","wt,het,wt","wt,hom,het","wt,wt,het","wt,wt,wt",random
4,"(31, 119, 180)","(174, 199, 232)","(255, 127, 14)","(255, 127, 14)","(255, 187, 120)","(44, 160, 44)","(31, 119, 180)","(152, 223, 138)","(214, 39, 40)","(255, 152, 150)",...,"(199, 199, 199)","(219, 219, 141)","(23, 190, 207)","(158, 218, 229)","(31, 119, 180)","(174, 199, 232)","(255, 127, 14)","(255, 187, 120)","(44, 160, 44)","(0, 0, 0)"
5,"rgba(31, 119, 180, 1.0)","rgba(174, 199, 232, 1.0)","rgba(255, 127, 14, 1.0)","rgba(255, 127, 14, 1.0)","rgba(255, 187, 120, 1.0)","rgba(44, 160, 44, 1.0)","rgba(31, 119, 180, 1.0)","rgba(152, 223, 138, 1.0)","rgba(214, 39, 40, 1.0)","rgba(255, 152, 150, 1.0)",...,"rgba(199, 199, 199, 1.0)","rgba(219, 219, 141, 1.0)","rgba(23, 190, 207, 1.0)","rgba(158, 218, 229, 1.0)","rgba(31, 119, 180, 1.0)","rgba(174, 199, 232, 1.0)","rgba(255, 127, 14, 1.0)","rgba(255, 187, 120, 1.0)","rgba(44, 160, 44, 1.0)","rgba(0, 0, 0, 0.0)"


In [18]:
# Recommended to run 
original_labels = labels.copy()
original_col_list = col_list.copy()

In [19]:
# As two new nodes have been included, this cell includes two empty labels for those new nodes and the link in white.
# This cell is primarily for editing the figure, if the labels or colors are changes they will appear in the final figure
for i in range(0,2):
    labels.append('')
    col_list.append('white')
assert len(labels) == len(col_list), "Labels and colors lists must have the same length"
# If the labels and colors list don´t have the same length ensure the correct number of additional labels and colors have been added

## Sankey diagram - Whole dataset

In [20]:
# Create the Sankey object

def Sankey_object(sankey_df):
    data = {'data': [{'type': 'sankey',
                'domain': {'x': [0, 1], 'y': [0, 1]},
                'orientation': 'h',
                'valueformat': '.0f',
                "arrangement": "freeform",
                'node': {
                        'pad': 30,
                        'thickness': 15,
                        'line': {'color': 'white', 'width': 0.0},  # define the appearance of the lines
                        'label': labels, # define the list of labels
                        'color': col_list, # define the colors for each label          
                        },
                'link':{'source': sankey_df.iloc[0], # Defines the source node
                        'target': sankey_df.iloc[1], # Defines the target node
                        'value': sankey_df.iloc[2], # Defines the value (aka width) that connects the previous two nodes
                        'label': sankey_df.iloc[3], # Defines the label of the node
                        'color': sankey_df.iloc[5], # Defines the colour of the line connecting two nodes based on genotype
                }
                }]
        }

    return data

In [21]:
sankey_obj = Sankey_object(sankey_df)

In [22]:
# Plot the Sankey diagram

def Sankey_plot(sankey_obj):
  
  start_time = time.time()

  fig = go.Figure(
    data=[go.Sankey(**sankey_obj['data'][0])]
  )

  fig.update_layout(title_text="Oedema evolution", font_size=10, autosize = False)
  fig.show()
  
  print("\n--- %s seconds ---" % (time.time() - start_time))

  return fig

In [23]:
plot = Sankey_plot(sankey_obj)


--- 0.12483096122741699 seconds ---


In [None]:
title = 'Triples_phenotype_dataset.html'
plot.write_html(title) # Save the Sankey plot as a .html file

## Sankey diagram - Individual samples

In [24]:
genolist.append('blank spacer') # Include a spacer to the original list of genotypes
print(genolist[:5])
print(len(genolist))

['hom,het,hom', 'hom,wt,hom', 'hom,hom,het', 'hom,hom,hom', 'hom,hom,wt']
26


In [25]:
# Define the colours for each genotype. The last one (blank spacer) should always be set as 'rgba(0, 0, 0, 0.0)'
# The length of the following dataframe must be equal to the length of the genolist varaible
genotype_colors = ['rgba(31, 119, 180,1.0)',
                    'rgba(174, 199, 232,1.0)',
                    'rgba(255, 127, 14,1.0)',
                    'rgba(255, 187, 120,1.0)',
                    'rgba(44, 160, 44,1.0)',
                    'rgba(152, 223, 138,1.0)',
                    'rgba(214, 39, 40,1.0)',
                    'rgba(255, 152, 150,1.0)',
                    'rgba(148, 103, 189,1.0)',
                    'rgba(197, 176, 213,1.0)',
                    'rgba(140, 86, 75,1.0)',
                    'rgba(196, 156, 148,1.0)',
                    'rgba(227, 119, 194,1.0)',
                    'rgba(247, 182, 210,1.0)',
                    'rgba(127, 127, 127,1.0)',
                    'rgba(199, 199, 199,1.0)',
                    'rgba(188, 189, 34,1.0)',
                    'rgba(219, 219, 141,1.0)',
                    'rgba(23, 190, 207,1.0)',
                    'rgba(158, 218, 229,1.0)',
                    'rgba(31, 119, 180,1.0)',
                    'rgba(174, 199, 232,1.0)',
                    'rgba(255, 127, 14,1.0)',
                    'rgba(255, 187, 120,1.0)',
                    'rgba(44, 160, 44,1.0)',
                    'rgba(0, 0, 0,0)',
                   ]

In [27]:
assert len(genolist) == len(genotype_colors), "Labels and colors lists must have the same length"
# The list of genotype_colors should have an additional row specifing the colour of the blank spaces

In [28]:
print(sankey_obj['data'][0]['link']['label'].unique().tolist()) # Get the list of all genotypes
geno_color = 'hom,hom,hom' # Define the genotype that will be colored

['hom,het,hom', 'hom,wt,hom', 'hom,hom,het', 'hom,hom,hom', 'hom,hom,wt', 'het,het,hom', 'het,hom,hom', 'hom,het,het', 'hom,het,wt', 'hom,wt,het', 'hom,wt,wt', 'wt,hom,hom', 'het,wt,hom', 'wt,het,hom', 'wt,wt,hom', 'het,het,wt', 'het,het,het', 'het,hom,het', 'het,hom,wt', 'het,wt,het', 'wt,het,het', 'wt,het,wt', 'wt,hom,het', 'wt,wt,het', 'wt,wt,wt', 'random']


In [29]:
# Plot the Sankey diagram

def Sankey_plot(sankey_obj):
  
  start_time = time.time()

  fig = go.Figure(
    data=[go.Sankey(**sankey_obj['data'][0])]
  )

  fig.update_layout(title_text="phenotype development", font_size=10, autosize = False)
  fig.show()
  
  print("\n--- %s seconds ---" % (time.time() - start_time))

  return fig

In [30]:
unique_genotypes = sankey_obj['data'][0]['link']['label'].unique().tolist() # Get the list of all genotypes
geno = genotype_colors

In [31]:
def Sankey_genomes(sankey_obj):

    sankey_geno = copy.deepcopy(sankey_obj) # Creates a copy of the original Sankey object to avoid overwriting
    col1 = ['rgba(169,169,169,0.4)'] * (len(sankey_geno['data'][0]['link']['color'])-1) # Creates a list of uniform colours (e.g. darkgrey = rgba(169,169,169,0.4)) with the length of the Sankey object
    col1.append('rgba(228, 26, 28, 0.0)')
    col1 = pd.Series(col1,index=sankey_geno['data'][0]['link']['color'].index) # Converts the list into a pandas dataframe
    sankey_geno['data'][0]['link']['color'] = col1 # Change the original colours for the uniform colour

    #unique_genotypes = list(set(sankey_geno['data'][0]['link']['label']))
    unique_genotypes = sankey_geno['data'][0]['link']['label'].unique().tolist()
    
    # Note: The length of the following list of colors needs to be equal to the number of genotypes
    geno = genotype_colors

    if len(unique_genotypes) == len(geno):
        print(f"List of colors and genotypes have the same length. Continue to plot")
        genotype_color_map = {genotype: geno[i % len(geno)] for i, genotype in enumerate(unique_genotypes)}
    
        for (label,label_ind) in zip(sankey_geno['data'][0]['link']['label'].tolist(), sankey_geno['data'][0]['link']['label'].index):
            if label == geno_color:
                sankey_geno['data'][0]['link']['color'][label_ind] = genotype_color_map[geno_color]

        plot_geno = Sankey_plot(sankey_geno) # Note: Any changes in the original Sankey_plot function will also affect the outcomes of these plots.

        return plot_geno
    
    else:
        print('The list of colours does not match the number of different genotypes')
    return sankey_geno

In [32]:
out = Sankey_genomes(sankey_obj)

List of colors and genotypes have the same length. Continue to plot



--- 0.014371156692504883 seconds ---


In [None]:
# Print and save all individual genotypes

for geno_color in sankey_obj['data'][0]['link']['label'].unique().tolist():
    if geno_color != 'random':
        out = Sankey_genomes(sankey_obj)
        title = geno_color + '.html'
        out.write_html(title)
    else:
        pass