In [155]:
# Cell types in differentiation order, and the assays measured for each of them.
CT_LIST = ['ESC', 'MES', 'CP', 'CM']
HM_LIST = ['H3K4me3', 'H3K27ac', 'H3K27me3', 'RNA']

# One "<assay>_<cell type>" prefix per combination, assay-major
# (all cell types of the first assay, then the second assay, ...).
PREFIXES = [f'{assay}_{cell_type}' for assay in HM_LIST for cell_type in CT_LIST]

# Extended marker-gene panel, keyed by cell type.
MARKER_GENES_EXT = {
    'ESC': ['Nanog', 'Pou5f1', 'Sox2', 'L1td1', 'Dppa5a', 'Tdh',
            'Esrrb', 'Lefty1', 'Zfp42', 'Sfn', 'Lncenc1', 'Utf1'],
    'MES': ['Mesp1', 'Mesp2', 'T', 'Vrtn', 'Dll3', 'Dll1',
            'Evx1', 'Cxcr4', 'Pcdh8', 'Pcdh19', 'Robo3', 'Slit1'],
    'CP': ['Sfrp5', 'Gata5', 'Tek', 'Hbb-bh1', 'Hba-x', 'Pyy',
           'Sox18', 'Lyl1', 'Rgs4', 'Igsf11', 'Tlx1', 'Ctse'],
    'CM': ['Nppa', 'Gipr', 'Actn2', 'Coro6', 'Col3a1', 'Bgn',
           'Myh6', 'Myh7', 'Tnni3', 'Hspb7', 'Igfbp7', 'Ndrg2'],
}

# Plot colours (hex strings), one lookup table per grouping variable.
HM_COL_DICT = {
    'H3K4me3': '#f37654',
    'H3K27ac': '#b62a77',
    'H3K27me3': '#39A8AC',
    'RNA': '#ED455C',
}
CT_COL_DICT = {
    'ESC': '#405074',
    'MES': '#7d5185',
    'CP': '#c36171',
    'CM': '#eea98d',
}
SET_COL_DICT = {
    'training': '#97DA58',
    'validation': '#9b58da',
    'test': '#DA5A58',
}
GONZALEZ_COL_DICT = {
    'Active': '#E5AA44',
    'Bivalent': '#7442BE',
}

In [156]:
import plotly.express as px
import pandas as pd
import pickle



# Load the gene-cluster dictionary.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by this project / a trusted source.
with open('./data/gene_clusters_dict.pkl', 'rb') as f:
    GENE_CLUSTERS = pickle.load(f)

# Load the CODE matrix (indexed by GENE), keep everything from the 19th column
# onwards and discard the AE / PCA / UMAP reconstruction metrics.
# Built as one non-mutating chain: the original dropped columns in place on an
# iloc slice, which is the pattern pandas' chained-assignment warnings target.
CODE = (
    pd.read_csv('./data/CODE.csv', index_col='GENE')
    .iloc[:, 18:]
    .drop(columns=['AE_RMSE', 'AE_Sc', 'PCA_RMSE', 'PCA_Sc', 'UMAP_RMSE', 'UMAP_Sc'])
)

# Load the LOG matrix, indexed by GENE.
LOG = pd.read_csv('./data/ALL_X_FC.csv').set_index('GENE')


In [157]:
N_TOP = 4000

# Categories stored for the TOP gene set.
name = 'TOP'
with open(f'./data/RNA_CV/{name}{N_TOP}/dict.pkl', 'rb') as f:
    CV = pickle.load(f)

# Categories stored for the BOTTOM gene set.
name = 'BOTTOM'
with open(f'./data/RNA_CV/{name}{N_TOP}/dict.pkl', 'rb') as f:
    BOTTOM_CV = pickle.load(f)

# Every BOTTOM gene, whatever its original category, is pooled under 'STABLE'.
STABLE = []
for bottom_genes in BOTTOM_CV.values():
    STABLE.extend(bottom_genes)
CV['STABLE'] = STABLE

# Invert {category: [genes]} into {gene: category}. A gene listed under several
# categories keeps the one that comes last in CV's iteration order.
CV_MAP = {
    gene: category
    for category, genes in CV.items()
    for gene in genes
}

# Label every gene in CODE; genes absent from CV_MAP are tagged 'other'.
cv_labels = pd.Series(CODE.index.map(CV_MAP), index=CODE.index)
CODE['CV_Category'] = cv_labels.fillna('other')

In [158]:
# Load the {category: [genes]} dictionary from gonzalez_dict.pkl.
with open('./data/gonzalez_dict.pkl', 'rb') as f:
    GONZALEZ = pickle.load(f)

# Invert it into {gene: category}; on duplicates the last category listed wins.
GONZALEZ_MAP = {
    gene: category
    for category, genes in GONZALEZ.items()
    for gene in genes
}

# Label every gene in CODE; genes absent from the dictionary are tagged 'other'.
gonzalez_labels = pd.Series(CODE.index.map(GONZALEZ_MAP), index=CODE.index)
CODE['ESC_ChromState_Gonzalez2021'] = gonzalez_labels.fillna('other')

# Cell output: number of genes per category.
CODE['ESC_ChromState_Gonzalez2021'].value_counts()

ESC_ChromState_Gonzalez2021
Active      9186
other       3495
Bivalent    2315
Name: count, dtype: int64

In [159]:

# Attach a cluster ID to every gene in CODE.
# Each GENE_CLUSTERS value is a mapping whose 'gene_list' entry holds the
# member genes; invert that into {gene: cluster_id}.
gene_to_cluster = {
    gene: cluster_id
    for cluster_id, cluster_info in GENE_CLUSTERS.items()
    for gene in cluster_info['gene_list']
}

# NOTE(review): the int cast presumably fails if any CODE gene has no cluster
# (unmapped genes come back as NaN) - confirm every gene is clustered.
CODE["Cluster"] = CODE.index.map(gene_to_cluster).astype(int)


In [160]:
# Keep only the columns used downstream, in a fixed order.
CODE = CODE[
    # RNA columns, one per cell type
    ['RNA_ESC', 'RNA_MES', 'RNA_CP', 'RNA_CM']
    # histone-mark columns, one per mark and cell type
    + ['H3K4me3_ESC', 'H3K4me3_MES', 'H3K4me3_CP', 'H3K4me3_CM']
    + ['H3K27ac_ESC', 'H3K27ac_MES', 'H3K27ac_CP', 'H3K27ac_CM']
    + ['H3K27me3_ESC', 'H3K27me3_MES', 'H3K27me3_CP', 'H3K27me3_CM']
    # pairwise RNA *_FC columns
    + ['RNA_CM_CP_FC', 'RNA_CM_MES_FC', 'RNA_CM_ESC_FC',
       'RNA_CP_MES_FC', 'RNA_CP_ESC_FC', 'RNA_MES_ESC_FC']
    # VAE metrics and per-gene annotations
    + ['VAE_RMSE', 'VAE_Sc']
    + ['RNA_CV', 'CV_Category', 'ESC_ChromState_Gonzalez2021', 'Cluster']
    # VAE1-6 and VAE_UMAP1-2 columns
    + ['VAE1', 'VAE2', 'VAE3', 'VAE4', 'VAE5', 'VAE6', 'VAE_UMAP1', 'VAE_UMAP2']
]

# Store cluster IDs as a categorical column rather than plain integers.
CODE['Cluster'] = pd.Categorical(CODE['Cluster'])

In [161]:
# Replicate-level RNA FPKM table, indexed by GENE.
RNA_FPKM = pd.read_csv('./data/RNA_FPKMs.csv', index_col='GENE')

# The two tables are joined column-wise, so they must list the same genes
# in the same order.
assert list(RNA_FPKM.index) == list(CODE.index)

# Join CODE and the FPKM columns side by side and save the combined table.
DATA = pd.concat([CODE, RNA_FPKM], axis='columns')
DATA.to_csv('./data/DATA.csv')

# Column groups of DATA.
Z_AVG_features = [
    'RNA_ESC', 'RNA_MES', 'RNA_CP', 'RNA_CM',
    'H3K4me3_ESC', 'H3K4me3_MES', 'H3K4me3_CP', 'H3K4me3_CM',
    'H3K27ac_ESC', 'H3K27ac_MES', 'H3K27ac_CP', 'H3K27ac_CM',
    'H3K27me3_ESC', 'H3K27me3_MES', 'H3K27me3_CP', 'H3K27me3_CM',
]
LOG_FC_features = [
    'RNA_CM_CP_FC', 'RNA_CM_MES_FC', 'RNA_CM_ESC_FC',
    'RNA_CP_MES_FC', 'RNA_CP_ESC_FC', 'RNA_MES_ESC_FC',
]
MISC_features = [
    'VAE_RMSE', 'VAE_Sc', 'RNA_CV',
    'CV_Category', 'ESC_ChromState_Gonzalez2021', 'Cluster',
]
LATENT_features = [
    'VAE1', 'VAE2', 'VAE3', 'VAE4', 'VAE5', 'VAE6',
    'VAE_UMAP1', 'VAE_UMAP2',
]
FPKM_features = [
    'RNA_ESC_1', 'RNA_ESC_2', 'RNA_MES_1', 'RNA_MES_2',
    'RNA_CP_1', 'RNA_CP_2', 'RNA_CM_1', 'RNA_CM_2',
]

# One sub-table per column group.
Z_AVG = DATA.loc[:, Z_AVG_features]
LOG_FC = DATA.loc[:, LOG_FC_features]
MISC = DATA.loc[:, MISC_features]
LATENT = DATA.loc[:, LATENT_features]
FPKM = DATA.loc[:, FPKM_features]


In [162]:
def gene_trend(MAIN,GENE_LIST,CT_LIST,CT_COL_DICT,Y_LAB='FPKMs'):
    """Draw a square grid of per-gene trend panels on a new matplotlib figure.

    Each panel shows the values of one gene as a strip plot coloured by cell
    type, with a dashed black line drawn through them by ``sns.lineplot``.

    Parameters:
        MAIN (pd.DataFrame): Rows are genes; each column label must contain
            one of the names in CT_LIST (and, optionally, a digit that is
            recorded as the replicate).
        GENE_LIST (list): Gene names (row labels of MAIN) to plot, one panel each.
        CT_LIST (list): Cell-type names to look for in the column labels.
        CT_COL_DICT (dict): Cell-type name -> colour, used as the strip-plot palette.
        Y_LAB (str): Y-axis label, also used as the value column name.

    Returns:
        None. The panels are drawn on the current matplotlib figure.

    Raises:
        KeyError: if a gene in GENE_LIST is not in MAIN's index.

    Notes:
        Uses the module-level names ``math``, ``np``, ``pd``, ``plt`` and
        ``sns``. NOTE(review): no ``plt``/``sns`` import is visible in this
        part of the notebook - confirm both are imported before calling.
    """
    num_genes = len(GENE_LIST)
    grid_size = math.ceil(math.sqrt(num_genes))

    plt.figure(figsize=(grid_size*3, grid_size*3))

    # The cell-type pattern does not depend on the gene, so build it once.
    CT_REG = '|'.join(CT_LIST)

    for i,GENE_NAME in enumerate(GENE_LIST):
        gene_values = MAIN.loc[GENE_NAME]
        # Cell type = first CT_LIST name found in each column label.
        CT = gene_values.index.str.extract(f'({CT_REG})')[0]
        # Replicate = first digit in each column label. Raw string: '\d' in a
        # normal literal is an invalid escape sequence and triggers a warning.
        REP = gene_values.index.str.extract(r'(\d)')[0]
        DF = pd.DataFrame({Y_LAB:gene_values.values, 'CT':CT, 'REP':REP})

        plt.subplot(grid_size, grid_size, i+1)
        plt.title(GENE_NAME)
        sns.stripplot(data=DF,x='CT',y=Y_LAB,hue='CT',palette=CT_COL_DICT,
                    s=12, alpha=1, legend=False,linewidth=0)
        sns.lineplot(data=DF,x='CT',y=Y_LAB, err_style=None,
                    color='black',linewidth=1, dashes=(2, 2))
        sns.despine(left=True,bottom=False,top=True)
        plt.xlabel('')
        plt.xticks([])

        # Only two y ticks (0 and the rounded-up maximum), with 10% headroom.
        y_max = max(DF[Y_LAB])
        plt.yticks(np.ceil([0, y_max]))
        plt.ylim(np.ceil([0, y_max*1.1]))

        # Plot y-axis label only for the first column in each row
        if i % grid_size != 0:
            plt.ylabel('')

    # Lay the grid out once all panels exist instead of after every panel.
    if num_genes:
        plt.tight_layout()
    


invalid escape sequence '\d'


invalid escape sequence '\d'


invalid escape sequence '\d'



In [None]:
# Rebind FPKM to the replicate-level FPKM columns of DATA.
# NOTE(review): filter(items=...) presumably skips listed columns that are
# missing instead of raising like DATA[FPKM_features] - confirm that is intended.
FPKM = DATA.filter(items=FPKM_features)


import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd
import math

def plot_gene_trend(DATA, SEL_GENES, CT_LIST, CT_COL_DICT, Y_LAB="FPKMs"):
    """
    Plot trends for selected genes across conditions using Plotly.

    One subplot is drawn per gene on a square grid. Every value is shown as a
    marker coloured by condition, and a dashed black line connects the
    per-condition means in CT_LIST order.

    Parameters:
        DATA (pd.DataFrame): Rows are genes; each column label must contain one
            of the names in CT_LIST and a digit (recorded as the replicate).
            Columns missing either are dropped before plotting.
        SEL_GENES (list): Gene names (row labels of DATA) to plot.
        CT_LIST (list): Conditions to extract, in plotting order.
        CT_COL_DICT (dict): Mapping of condition names to colors.
        Y_LAB (str): Label for the Y-axis (e.g., FPKMs).

    Returns:
        Plotly figure with one subplot per gene.

    Raises:
        ValueError: if SEL_GENES is empty (a 0 x 0 subplot grid cannot be built).
        KeyError: if a gene is not in DATA's index, or a condition in CT_LIST
            has no entry in CT_COL_DICT.
    """
    num_genes = len(SEL_GENES)
    if num_genes == 0:
        raise ValueError("SEL_GENES must contain at least one gene")
    grid_size = math.ceil(math.sqrt(num_genes))  # Create a square grid layout

    # Create subplot figure
    fig = make_subplots(
        rows=grid_size, cols=grid_size,
        subplot_titles=list(SEL_GENES),
        horizontal_spacing=0.05, vertical_spacing=0.1
    )

    # The condition pattern does not depend on the gene, so build it once.
    CT_REG = '|'.join(CT_LIST)

    for i, gene_name in enumerate(SEL_GENES):
        # Extract gene data
        gene_data = DATA.loc[gene_name]

        # Extract condition (CT) and replicate (REP) from the column labels.
        # REP is not plotted, but rows without a digit are removed by dropna().
        CT = gene_data.index.str.extract(f'({CT_REG})')[0]
        REP = gene_data.index.str.extract(r'(\d)')[0]
        df = pd.DataFrame({Y_LAB: gene_data.values, 'CT': CT, 'REP': REP})

        # Filter out invalid rows
        df = df.dropna()

        # Determine subplot location
        row = (i // grid_size) + 1
        col = (i % grid_size) + 1

        # Add scatter plot for individual points
        for ct in CT_LIST:
            ct_data = df[df['CT'] == ct]
            fig.add_trace(
                go.Scatter(
                    x=ct_data['CT'],
                    y=ct_data[Y_LAB],
                    mode='markers',
                    marker=dict(size=8, color=CT_COL_DICT[ct]),
                    name=ct,
                    showlegend=False,  # Avoid duplicating legends
                    # Doubled braces keep a literal %{y} in the f-string so
                    # Plotly substitutes the point's y value on hover.
                    hovertemplate=f"<b>{ct}</b><br>{Y_LAB}: %{{y}}<extra></extra>"
                ),
                row=row, col=col
            )

        # Add line plot for trend: one point per condition. groupby sorts its
        # keys alphabetically, so reorder to CT_LIST and take x from the same
        # Series so that x and y always match in length and order.
        ct_means = df.groupby('CT')[Y_LAB].mean().reindex(CT_LIST).dropna()
        fig.add_trace(
            go.Scatter(
                x=ct_means.index.tolist(), y=ct_means.tolist(),
                mode='lines',
                line=dict(color='black', dash='dash', width=1),
                showlegend=False,
                hoverinfo='skip'
            ),
            row=row, col=col
        )

        # Update subplot axes: hide x tick labels, label y only in the first column
        fig.update_xaxes(title_text="", row=row, col=col, showticklabels=False)
        fig.update_yaxes(title_text=Y_LAB if col == 1 else "", row=row, col=col)

    # Update figure layout
    fig.update_layout(
        height=300 * grid_size, width=300 * grid_size,
        title_text="Gene Trends Across Conditions",
        showlegend=False,
        plot_bgcolor="white",
    )

    return fig

# Show the FPKM trend panels for Myh6 and Myh7 (the figure is the cell output).
plot_gene_trend(
    DATA=FPKM,
    SEL_GENES=['Myh6','Myh7'],
    CT_LIST=CT_LIST,
    CT_COL_DICT=CT_COL_DICT,
    Y_LAB="FPKMs",
)
