In [179]:
# Cell-type labels and signal labels used to build column names.
CT_LIST = ["ESC", "MES", "CP", "CM"]
HM_LIST = ["H3K4me3", "H3K27ac", "H3K27me3", "RNA"]

# Every "<signal>_<cell type>" combination, signal-major
# (all cell types of the first signal, then the next signal, ...).
PREFIXES = [f"{signal}_{cell_type}" for signal in HM_LIST for cell_type in CT_LIST]

# Extended gene lists keyed by cell type (12 genes per cell type).
MARKER_GENES_EXT = {
    "ESC": [
        "Nanog", "Pou5f1", "Sox2", "L1td1", "Dppa5a", "Tdh",
        "Esrrb", "Lefty1", "Zfp42", "Sfn", "Lncenc1", "Utf1",
    ],
    "MES": [
        "Mesp1", "Mesp2", "T", "Vrtn", "Dll3", "Dll1",
        "Evx1", "Cxcr4", "Pcdh8", "Pcdh19", "Robo3", "Slit1",
    ],
    "CP": [
        "Sfrp5", "Gata5", "Tek", "Hbb-bh1", "Hba-x", "Pyy",
        "Sox18", "Lyl1", "Rgs4", "Igsf11", "Tlx1", "Ctse",
    ],
    "CM": [
        "Nppa", "Gipr", "Actn2", "Coro6", "Col3a1", "Bgn",
        "Myh6", "Myh7", "Tnni3", "Hspb7", "Igfbp7", "Ndrg2",
    ],
}

# Hex colour palettes, one dict per kind of label.
HM_COL_DICT = {
    "H3K4me3": "#f37654",
    "H3K27ac": "#b62a77",
    "H3K27me3": "#39A8AC",
    "RNA": "#ED455C",
}
CT_COL_DICT = {
    "ESC": "#405074",
    "MES": "#7d5185",
    "CP": "#c36171",
    "CM": "#eea98d",
}
SET_COL_DICT = {
    "training": "#97DA58",
    "validation": "#9b58da",
    "test": "#DA5A58",
}
GONZALEZ_COL_DICT = {
    "Active": "#E5AA44",
    "Bivalent": "#7442BE",
}

In [180]:
import plotly.express as px
import pandas as pd
import pickle



# Load gene cluster dictionary.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that were produced locally / come from a trusted source.
with open('./data/gene_clusters_dict.pkl', 'rb') as f:
    GENE_CLUSTERS = pickle.load(f)

# Load the CODE matrix (indexed by GENE), discard its first 18 columns and the
# AE / PCA / UMAP score columns. Built as one non-mutating chain so that
# re-running the cell always starts from the file on disk.
CODE = (
    pd.read_csv('./data/CODE.csv', index_col='GENE')
    .iloc[:, 18:]
    .drop(columns=['AE_RMSE', 'AE_Sc', 'PCA_RMSE', 'PCA_Sc', 'UMAP_RMSE', 'UMAP_Sc'])
)

# Load the LOG matrix, indexed by GENE.
LOG = pd.read_csv('./data/ALL_X_FC.csv').set_index('GENE')


In [181]:
N_TOP = 4000

# Category -> gene list dictionary stored under the TOP<N_TOP> folder.
name = 'TOP'
with open(f'./data/RNA_CV/{name}{N_TOP}/dict.pkl', 'rb') as f:
    CV = pickle.load(f)

# Same layout, stored under the BOTTOM<N_TOP> folder.
name = 'BOTTOM'
with open(f'./data/RNA_CV/{name}{N_TOP}/dict.pkl', 'rb') as f:
    BOTTOM_CV = pickle.load(f)

# Pool every BOTTOM gene list into a single extra 'STABLE' category of CV.
STABLE = []
for bottom_genes in BOTTOM_CV.values():
    STABLE.extend(bottom_genes)
CV['STABLE'] = STABLE

# Invert category -> genes into gene -> category. Categories are visited in
# CV's insertion order, so a gene listed more than once keeps the last one
# (and 'STABLE', added last, takes precedence).
CV_MAP = {
    gene: category
    for category, genes in CV.items()
    for gene in genes
}

# Annotate CODE; genes absent from every list are labelled 'other'.
CODE['CV_Category'] = CODE.index.map(CV_MAP)
CODE['CV_Category'] = CODE['CV_Category'].fillna('other')

In [182]:
# Category -> gene list dictionary loaded from the Gonzalez pickle.
with open('./data/gonzalez_dict.pkl', 'rb') as gonzalez_file:
    GONZALEZ = pickle.load(gonzalez_file)

# Invert to gene -> category; a gene listed under several categories keeps
# the last one encountered.
GONZALEZ_MAP = {
    gene: category
    for category, genes in GONZALEZ.items()
    for gene in genes
}

# Annotate CODE; genes absent from the dictionary are labelled 'other'.
gonzalez_column = 'ESC_ChromState_Gonzalez2021'
CODE[gonzalez_column] = CODE.index.map(GONZALEZ_MAP)
CODE[gonzalez_column] = CODE[gonzalez_column].fillna('other')

# Displayed as the cell output: number of genes per category.
CODE[gonzalez_column].value_counts()

ESC_ChromState_Gonzalez2021
Active      9186
other       3495
Bivalent    2315
Name: count, dtype: int64

In [183]:

# Attach a cluster ID to every gene in CODE.
# Each GENE_CLUSTERS entry holds its member genes under the 'gene_list' key;
# invert that into gene -> cluster_id (a gene listed in several clusters keeps
# the last one encountered).
gene_to_cluster = {
    gene: cluster_id
    for cluster_id, cluster_entry in GENE_CLUSTERS.items()
    for gene in cluster_entry['gene_list']
}

# The integer cast only succeeds when every gene in CODE received a cluster.
cluster_ids = CODE.index.map(gene_to_cluster)
CODE["Cluster"] = cluster_ids.astype(int)


In [184]:
# Columns kept in CODE, in their final order, grouped by kind.
_CODE_COLUMN_ORDER = [
    # gene annotations
    'Cluster', 'RNA_CV', 'CV_Category', 'ESC_ChromState_Gonzalez2021',
    # RNA per cell type
    'RNA_ESC', 'RNA_MES', 'RNA_CP', 'RNA_CM',
    # histone marks per cell type
    'H3K4me3_ESC', 'H3K4me3_MES', 'H3K4me3_CP', 'H3K4me3_CM',
    'H3K27ac_ESC', 'H3K27ac_MES', 'H3K27ac_CP', 'H3K27ac_CM',
    'H3K27me3_ESC', 'H3K27me3_MES', 'H3K27me3_CP', 'H3K27me3_CM',
    # pairwise RNA fold changes
    'RNA_CM_CP_FC', 'RNA_CM_MES_FC', 'RNA_CM_ESC_FC',
    'RNA_CP_MES_FC', 'RNA_CP_ESC_FC', 'RNA_MES_ESC_FC',
    # VAE scores
    'VAE_RMSE', 'VAE_Sc',
    # VAE latent dimensions and their UMAP embedding
    'VAE1', 'VAE2', 'VAE3', 'VAE4', 'VAE5', 'VAE6', 'VAE_UMAP1', 'VAE_UMAP2',
]

# Select/reorder the columns, then store Cluster as a categorical column.
CODE = CODE[_CODE_COLUMN_ORDER].assign(
    Cluster=lambda frame: pd.Categorical(frame['Cluster'])
)

In [185]:
# Per-replicate RNA FPKM table, indexed by GENE.
RNA_FPKM = pd.read_csv('./data/RNA_FPKMs.csv', index_col='GENE')

# The two tables are joined side by side below, so they must list exactly the
# same genes in exactly the same order.
assert list(RNA_FPKM.index) == list(CODE.index), (
    "RNA_FPKM and CODE must have identical gene indices (same genes, same order)"
)

# Concatenate the two dataframes column-wise and persist the combined table.
DATA = pd.concat([CODE, RNA_FPKM], axis=1)
DATA.to_csv('./data/DATA.csv')

# Column groups of DATA.
Z_AVG_features = [
    'RNA_ESC', 'RNA_MES', 'RNA_CP', 'RNA_CM',
    'H3K4me3_ESC', 'H3K4me3_MES', 'H3K4me3_CP', 'H3K4me3_CM',
    'H3K27ac_ESC', 'H3K27ac_MES', 'H3K27ac_CP', 'H3K27ac_CM',
    'H3K27me3_ESC', 'H3K27me3_MES', 'H3K27me3_CP', 'H3K27me3_CM',
]
LOG_FC_features = [
    'RNA_CM_CP_FC', 'RNA_CM_MES_FC', 'RNA_CM_ESC_FC',
    'RNA_CP_MES_FC', 'RNA_CP_ESC_FC', 'RNA_MES_ESC_FC',
]
MISC_features = [
    'VAE_RMSE', 'VAE_Sc', 'RNA_CV', 'CV_Category',
    'ESC_ChromState_Gonzalez2021', 'Cluster',
]
LATENT_features = [
    'VAE1', 'VAE2', 'VAE3', 'VAE4', 'VAE5', 'VAE6', 'VAE_UMAP1', 'VAE_UMAP2',
]
FPKM_features = [
    'RNA_ESC_1', 'RNA_ESC_2', 'RNA_MES_1', 'RNA_MES_2',
    'RNA_CP_1', 'RNA_CP_2', 'RNA_CM_1', 'RNA_CM_2',
]

# One sub-frame per column group; plain column selection raises a KeyError
# if any listed column is missing from DATA.
Z_AVG = DATA[Z_AVG_features]
LOG_FC = DATA[LOG_FC_features]
MISC = DATA[MISC_features]
LATENT = DATA[LATENT_features]
FPKM = DATA[FPKM_features]


In [119]:
# Re-derive the FPKM sub-frame from DATA. Unlike plain column selection,
# DataFrame.filter(items=...) keeps only the listed columns that exist and
# does not raise when one is missing.
FPKM = DATA.filter(items=FPKM_features)


import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd



In [175]:
def plot_sankey(DATA, SEL_GENES, FEATURE_CAT, COL_DICT, font_size=12, font_color="black", font_family="Arial", link_opacity=0.5):
    """
    Create a three-layer Sankey diagram: FEATURE_CAT values -> genes -> clusters.

    Every selected gene contributes one link of weight 1 from its FEATURE_CAT
    value to the gene, and one link of weight 1 from the gene to its cluster.
    Category-to-gene links are drawn in light gray; gene-to-cluster links are
    colored by the gene's FEATURE_CAT value (looked up in COL_DICT).

    Parameters:
    - DATA (pd.DataFrame): Input data with genes as index. Must contain the
      FEATURE_CAT column and a "Cluster" column.
    - SEL_GENES (list): Selected gene names. Duplicates are ignored (first
      occurrence wins) so that each gene appears as exactly one node.
    - FEATURE_CAT (str): The column name for categorical features to add as the first layer.
    - COL_DICT (dict): Maps FEATURE_CAT values to colors. Values missing from
      the dict fall back to gray ("#808080").
    - font_size (int): Font size for node labels.
    - font_color (str): Font color for node labels.
    - font_family (str): Font family for node labels.
    - link_opacity (float): Opacity for the gene-to-cluster links (0.0 to 1.0).

    Returns:
    - fig: A Plotly Sankey figure.

    Raises:
    - ValueError: If FEATURE_CAT is not a column of DATA.
    - KeyError: If any selected gene is not in DATA's index.
    """
    import matplotlib.colors as mcolors

    if FEATURE_CAT not in DATA.columns:
        raise ValueError(f"The feature '{FEATURE_CAT}' is not in the provided DATA.")

    # Order-preserving de-duplication: a repeated gene would otherwise create
    # an orphan node and double its links.
    gene_nodes = list(dict.fromkeys(SEL_GENES))

    missing_genes = [gene for gene in gene_nodes if gene not in DATA.index]
    if missing_genes:
        raise KeyError(f"Selected genes not found in DATA index: {missing_genes}")

    # Filter for selected genes
    data_filtered = DATA.loc[gene_nodes]

    # Node labels, layer by layer.
    feature_nodes = data_filtered[FEATURE_CAT].unique().tolist()  # first layer
    cluster_nodes = data_filtered["Cluster"].unique().tolist()  # last layer
    all_nodes = feature_nodes + gene_nodes + cluster_nodes

    # One label -> node-index map per layer, so that a label occurring in two
    # layers (e.g. FEATURE_CAT == "Cluster") still resolves to the right node.
    gene_offset = len(feature_nodes)
    cluster_offset = gene_offset + len(gene_nodes)
    feature_index = {label: i for i, label in enumerate(feature_nodes)}
    gene_index = {label: gene_offset + i for i, label in enumerate(gene_nodes)}
    cluster_index = {label: cluster_offset + i for i, label in enumerate(cluster_nodes)}

    def _rgba_string(color, alpha):
        """Convert any matplotlib color spec to a CSS 'rgba(r, g, b, a)' string."""
        red, green, blue, opacity = mcolors.to_rgba(color, alpha=alpha)
        return f"rgba({int(red * 255)}, {int(green * 255)}, {int(blue * 255)}, {opacity})"

    # (gene, category value, cluster) for every selected row.
    rows = list(zip(data_filtered.index, data_filtered[FEATURE_CAT], data_filtered["Cluster"]))

    sources, targets, link_colors, link_hover = [], [], [], []

    # Links from FEATURE_CAT to genes (light gray).
    for gene, feature_value, _ in rows:
        sources.append(feature_index[feature_value])
        targets.append(gene_index[gene])
        link_colors.append("rgba(192,192,192,0.3)")
        link_hover.append(f"Category: {feature_value}<br>Gene: {gene}")

    # Links from genes to clusters (colored by the gene's FEATURE_CAT value).
    for gene, feature_value, cluster in rows:
        sources.append(gene_index[gene])
        targets.append(cluster_index[cluster])
        link_colors.append(_rgba_string(COL_DICT.get(feature_value, "#808080"), link_opacity))
        link_hover.append(f"Gene: {gene}<br>Cluster: {cluster}")

    # Node colors: FEATURE_CAT nodes use COL_DICT, gene and cluster nodes are silver.
    node_colors = (
        [COL_DICT.get(label, "#808080") for label in feature_nodes]
        + ["silver"] * (len(gene_nodes) + len(cluster_nodes))
    )

    # Create Sankey diagram
    fig = go.Figure(go.Sankey(
        textfont=dict(family=font_family, size=font_size, color=font_color, shadow=None),
        node=dict(
            pad=5,
            thickness=10,
            line=dict(color="black", width=0.5),
            label=all_nodes,
            color=node_colors,
            hovertemplate='%{label}<extra></extra>',
        ),
        link=dict(
            source=sources,
            target=targets,
            value=[1] * len(sources),  # Equal weight for all links
            color=link_colors,
            # Per-link hover text, so both link layers are labelled correctly.
            customdata=link_hover,
            hovertemplate='%{customdata}<extra></extra>',
        )
    ))

    # Update layout
    fig.update_layout(
        title_text="",
        margin=dict(t=50, l=25, r=25, b=25),
    )
    # Add titles for layers
    fig.add_annotation(x=0.02, y=1.1, text="Category", showarrow=False, font=dict(size=16))
    fig.add_annotation(x=0.5, y=1.1, text="Gene", showarrow=False, font=dict(size=16))
    fig.add_annotation(x=0.98, y=1.1, text="Cluster", showarrow=False, font=dict(size=16), xanchor='right')

    return fig


In [177]:
# Genes shown in the Sankey diagram.
SEL_GENES = [
    'Nanog', 'Pou5f1', 'Sox2', 'Dppa5a',
    'Mesp1', 'T', 'Vrtn', 'Dll3',
    'Gata5', 'Tek', 'Sox18', 'Lyl1',
    'Actn2', 'Coro6', 'Myh6', 'Myh7',
]

# Categorical column used as the first Sankey layer, and one color per value.
CAT_FEATURE = 'CV_Category'
COL_DICT = dict(
    RNA_ESC='#405074',
    RNA_MES='#7d5185',
    RNA_CP='#c36171',
    RNA_CM='#eea98d',
    STABLE='#B4CD70',
    other='#ECECEC',
)

fig = plot_sankey(
    DATA,
    SEL_GENES,
    CAT_FEATURE,
    COL_DICT,
    font_size=14,
    font_color='white',
    link_opacity=0.9,
)
# Displaying the figure inline requires nbformat>=4.2.0 in the kernel
# environment (the recorded run of this cell failed because it was missing).
fig


ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed