### Set parameters and names

In [None]:
from CustomObjects import *
%matplotlib inline
%config InlineBackend.figure_format = 'retina'


In [None]:
# Run configuration: these values build the ID / directory names of the
# embedding + clustering run analysed in this notebook.
INPUT_GENES    = 'ALL'
INPUT_FEATURES = 'X_FC'
INPUT_NORM     = 'z'
CODINGS_SIZE = 6  # appears as the '{CODINGS_SIZE}D' prefix of ID (presumably the latent size — confirm)

ID     = f'{CODINGS_SIZE}D_{INPUT_GENES}_{INPUT_FEATURES}_{INPUT_NORM}'
METHOD = 'VAE'
k      = 80  # cluster count, embedded in the GMM label name and figure dir
LABELS_COL = f'GMM_{METHOD}_{k}'
# Input directory; the trailing slash matters because file names are appended directly.
DIR_DATA= f'../data/{ID}_analysis/{LABELS_COL}/'
# Output directory for every figure saved below (created by the shell magic).
DIR_FIG = f'../figures/{ID}_analysis/{METHOD}/Clusters_k{k}'
! mkdir -p {DIR_FIG}
# Name of the cluster-label column added to CODE / LOG (same string as LABELS_COL).
GMM_LABELS = f'GMM_{METHOD}_{k}'


In [None]:
# Load the pickled cluster dictionary for this run (trusted local file; pickle
# must never be used on untrusted input).
# NOTE(review): entries are read below as GENE_CLUSTERS[id]['gene_list'] and
# GENE_CLUSTERS[id]['len']; ids are passed through int() — confirm the schema.
with open(f'{DIR_DATA}gene_clusters_dict.pkl', 'rb') as f:
    GENE_CLUSTERS = pickle.load(f)


In [None]:

def cluster_intersection(CAT_DICT, gene_clusters):
    """Count how many genes of each category fall into each gene cluster.

    Parameters
    ----------
    CAT_DICT : dict
        Category name (e.g. cell type or gene-set name) -> iterable of genes.
    gene_clusters : dict
        Cluster id (anything ``int()`` accepts, e.g. ``'3'`` or ``3``) ->
        dict holding at least a ``'gene_list'`` entry.

    Returns
    -------
    pandas.DataFrame
        One row per cluster (integer ids, in ``gene_clusters`` order) and one
        column per category (in ``CAT_DICT`` order). Each cell is the number
        of distinct genes shared by that category and that cluster.
    """
    # Build each cluster's gene set once, rather than once per category.
    cluster_sets = [
        (int(cluster_id), set(cluster_info['gene_list']))
        for cluster_id, cluster_info in gene_clusters.items()
    ]
    df = pd.DataFrame(
        0,
        index=[cluster_id for cluster_id, _ in cluster_sets],
        columns=list(CAT_DICT.keys()),
    )

    for CAT, GENE_LIST in CAT_DICT.items():
        category_genes = set(GENE_LIST)
        for cluster_id, cluster_genes in cluster_sets:
            # Number of this category's genes that belong to the cluster.
            df.at[cluster_id, CAT] = len(category_genes & cluster_genes)

    return df



In [None]:

def plot_barplots_clusters(df, COL_DICT, TITLE=None, Y_LAB = 'Counts', ADD_LABELS=False):
    """Draw one small bar plot per cluster (row of ``df``) on an 8-column grid.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows are clusters (index values are used as subplot titles ``C<label>``);
        columns are categories, drawn as one bar each.
    COL_DICT : dict
        Category name -> colour, passed to seaborn as ``palette``.
    TITLE : str, optional
        Figure super-title.
    Y_LAB : str
        Y-axis label, shown only on the first column of subplots.
    ADD_LABELS : bool
        If True, annotate every bar with height > 0 with its integer value.

    The figure is left open (current figure) so the caller can ``savefig`` it.
    """
    n_cols = 8
    unique_clusters = df.index.unique()
    num_clusters = len(unique_clusters)
    # Enough rows for every cluster: floor division dropped the remainder and
    # produced zero rows for fewer than 8 clusters.
    n_rows = max(1, (num_clusters + n_cols - 1) // n_cols)

    # squeeze=False keeps ``axes`` 2-D even when there is a single row.
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(4 + ((num_clusters // 5) * 0.75), 12),
                             squeeze=False)

    plt.suptitle(TITLE, fontsize=20, y=1.02)

    # Shared y-limit so bar heights are comparable across subplots.
    max_value = df.max().max()

    for idx, cluster in enumerate(unique_clusters):
        row_index, col_index = divmod(idx, n_cols)
        ax = axes[row_index, col_index]

        # Select by label: positional .iloc[cluster] picks the wrong row
        # whenever the index is not exactly 0..n-1 in order.
        cluster_data = df.loc[cluster]

        sns.barplot(ax=ax, x=cluster_data.index, y=cluster_data.values, hue=cluster_data.index,
                    palette=COL_DICT, saturation=1)

        if ADD_LABELS:
            # Annotate each non-empty bar with its integer height.
            for p in ax.patches:
                height = p.get_height()
                if height > 0:
                    ax.annotate(f'{int(height)}',
                                (p.get_x() + p.get_width() / 2., height),
                                ha='center', va='center',
                                xytext=(0, 5),
                                textcoords='offset points')

        ax.set_title(f'C{cluster}', size=14)

        ax.set_ylim(0, max_value)
        ax.set_yticks([max_value])
        ax.set_yticklabels([f'{int(max_value)}'], fontsize=12)

        # Hide horizontal grid lines
        ax.yaxis.grid(False)

        # Only the bottom plot of each column keeps x tick labels and the
        # x-axis label (the old test compared against num_clusters // 5,
        # which never matched this 8-column layout).
        if idx + n_cols >= num_clusters:
            ax.set_xlabel('Cell Type', fontsize=12)
        else:
            ax.set_xticklabels([])
            ax.set_xlabel('')

        # Y tick labels and y-axis label only on the first column.
        if col_index != 0:
            ax.set_yticklabels([])
            ax.set_ylabel('')
        else:
            ax.set_ylabel(Y_LAB, fontsize=12)

    # Remove grid cells left empty when the cluster count is not a multiple of n_cols.
    for extra_ax in axes.flat[num_clusters:]:
        fig.delaxes(extra_ax)

    sns.despine(fig=fig)
    plt.tight_layout()

In [None]:
import squarify

def plot_treemap_clusters(df, COL_DICT, TITLE=None, RATIO=5, MAX_AREA=None):
    """Draw one squarify treemap per cluster (row of ``df``) on an 8-column grid.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows are clusters (index values are used as subplot titles ``C<label>``);
        columns are categories. Only strictly positive values are drawn.
    COL_DICT : dict
        Category name -> colour. It is not modified. A category named
        ``'other'`` is always drawn white.
    TITLE : str, optional
        Figure super-title.
    RATIO : float
        Scales the figure height: ``((12 / 5) / RATIO) * 12``.
    MAX_AREA : number, optional
        If truthy, `` (Max Area: <int>%)`` is appended to the title.

    The figure is left open (current figure) so the caller can ``savefig`` it.
    """
    n_cols = 8
    # Keep the original row labels for the titles, then index rows by position.
    cluster_labels = list(df.index)
    df = df.reset_index(drop=True)

    if MAX_AREA:
        TITLE = (TITLE or '') + f" (Max Area: {int(MAX_AREA)}%)"

    # Copy instead of mutating the caller's palette; the 'other' remainder is white.
    colour_map = {**COL_DICT, 'other': 'white'}

    num_clusters = df.shape[0]
    # Exactly as many rows as needed ('// 8 + 1' added an empty row for multiples of 8).
    rows = max(1, (num_clusters + n_cols - 1) // n_cols)
    # squeeze=False + ravel gives a flat array of axes for any number of rows.
    fig, axes = plt.subplots(rows, n_cols,
                             figsize=(4 + ((num_clusters // 5) * 1.5), ((12 / 5) / RATIO) * 12),
                             squeeze=False)
    axes = axes.ravel()

    plt.suptitle(TITLE, fontsize=20, y=1.02)

    for idx in range(num_clusters):
        # Only strictly positive areas can be drawn (drops zeros, NaN and negative remainders).
        cluster_data = df.iloc[idx]
        cluster_data = cluster_data[cluster_data > 0]

        sizes = cluster_data.values
        colors = [colour_map[label] for label in cluster_data.index]

        squarify.plot(sizes=sizes, color=colors, ax=axes[idx], alpha=1, pad=0, ec='black',
                      norm_x=100,
                      norm_y=5)

        axes[idx].set_title(f'C{cluster_labels[idx]}', size=34)
        axes[idx].axis('off')

    # Hide unused subplots (safe even when df is empty).
    for extra_ax in axes[num_clusters:]:
        fig.delaxes(extra_ax)

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)  # Adjust top to make room for suptitle

# TOP CV (CT max) and BOTTOM CV

In [None]:
# Load the top-CV and bottom-CV gene-list dictionaries (directories TOP{N_TOP} /
# BOTTOM{N_TOP} under ../data/RNA_CV/). These are trusted local pickles.
name = 'TOP'
N_TOP = 4000
with open(f'../data/RNA_CV/{name}{N_TOP}/dict.pkl', 'rb') as f:
    TOP_CV = pickle.load(f)
    
# Bottom-CV lists, same N_TOP suffix
name = 'BOTTOM'
with open(f'../data/RNA_CV/{name}{N_TOP}/dict.pkl', 'rb') as f:
    BOTTOM_CV = pickle.load(f)
# Pool every bottom-CV list into a single 'STABLE' category added to TOP_CV.
STABLE = [gene for gene_list in BOTTOM_CV.values() for gene in gene_list]
TOP_CV['STABLE'] = STABLE

In [None]:
# Colour per TOP_CV category. The keys must cover every TOP_CV key, because the
# plotting functions look colours up by category name.
CV_COL_DICT = {'RNA_ESC': '#405074', 'RNA_MES': '#7d5185', 'RNA_CP': '#c36171', 'RNA_CM': '#eea98d',
                'STABLE':'#B4CD6F'}
# NOTE(review): export_legend is not defined in this notebook (presumably from CustomObjects);
# it is given the palette and a PDF path.
export_legend(CV_COL_DICT, f'{DIR_FIG}/CV_legend.pdf')


### Intersection as % of list in the cluster

In [None]:
# Overlap counts (clusters x CV categories), then each column is divided by its
# total: values are the % of a category's clustered genes found in each cluster.
# Each column sums to 100 (NaN if the category overlaps no cluster).
INTERSECTION_CV = cluster_intersection(TOP_CV, GENE_CLUSTERS)
COL_SUM = INTERSECTION_CV.sum()
INTERSECTION_CV = INTERSECTION_CV.div(COL_SUM, axis=1) * 100
#
plot_barplots_clusters(INTERSECTION_CV, CV_COL_DICT, TITLE=f'CellType max CV overlaps (% of list)',
                        Y_LAB='% of list', ADD_LABELS=False)
plt.savefig(f'{DIR_FIG}/Intersection_TOP{N_TOP}_CV_perc_list.pdf', format="pdf", bbox_inches="tight")



### Intersection as % of cluster belonging to the list

In [None]:
# Cluster sizes from each cluster's 'len' entry, indexed by integer cluster id.
len_values = {int(key): GENE_CLUSTERS[key]['len'] for key in GENE_CLUSTERS}
CLUSTERS_LEN = pd.Series(len_values).sort_index()
#
# Overlap counts divided row-wise by cluster size: % of each cluster's genes
# that belong to each CV category.
INTERSECTION_CV = cluster_intersection(TOP_CV, GENE_CLUSTERS)
INTERSECTION_CV = INTERSECTION_CV.div(CLUSTERS_LEN, axis=0) * 100
#
plot_barplots_clusters(INTERSECTION_CV, CV_COL_DICT, TITLE=f'CellType max CV overlap (% of cluster)',
                        Y_LAB='% of cluster', ADD_LABELS=False)
plt.savefig(f'{DIR_FIG}/Intersection_TOP{N_TOP}_CV_perc_cluster.pdf', format="pdf", bbox_inches="tight")



In [None]:
# Cluster sizes, indexed by integer cluster id (recomputed so this cell runs on its own).
len_values = {int(key): GENE_CLUSTERS[key]['len'] for key in GENE_CLUSTERS}
CLUSTERS_LEN = pd.Series(len_values).sort_index()
#
# % of each cluster's genes in each CV category. 'other' is the remainder up to
# 100%; a negative remainder (overlapping categories) is skipped by the
# treemap's > 0 filter.
INTERSECTION_CV = cluster_intersection(TOP_CV, GENE_CLUSTERS)
INTERSECTION_CV = INTERSECTION_CV.div(CLUSTERS_LEN, axis=0) * 100
INTERSECTION_CV['other'] = 100 - INTERSECTION_CV.sum(axis=1)

plot_treemap_clusters(INTERSECTION_CV, CV_COL_DICT, TITLE=f'CellType max CV overlap (% of cluster)',
                        RATIO=1.5, MAX_AREA=100)
plt.savefig(f'{DIR_FIG}/Intersection_TOP{N_TOP}_CV_perc_cluster_TREEMAP.pdf', format="pdf", bbox_inches="tight")


### Gene sets

In [None]:
# One colour per gene set from the 'rainbow' palette.
# NOTE(review): GENE_SETS and assign_palette are not defined in this notebook — presumably from CustomObjects.
GS_COL_DICT = assign_palette('rainbow', GENE_SETS)

#### % of list

In [None]:
# Gene-set overlaps normalised per gene set: each column is divided by its total,
# so values are the % of a gene set's clustered genes found in each cluster.
INTERSECTION_GS = cluster_intersection(GENE_SETS, GENE_CLUSTERS)
COL_SUM = INTERSECTION_GS.sum()
INTERSECTION_GS = INTERSECTION_GS.div(COL_SUM, axis=1) * 100
#
export_legend(GS_COL_DICT, f'{DIR_FIG}/GENE_SETS_legend.pdf')
plot_barplots_clusters(INTERSECTION_GS, GS_COL_DICT, TITLE=f'GeneSets overlap (% of list)',
                        Y_LAB='% of list', ADD_LABELS=False)
plt.savefig(f'{DIR_FIG}/Intersection_GENE_SETS_perc_list.pdf', format="pdf", bbox_inches="tight")


#### % of cluster

In [None]:
# Gene-set overlaps as % of each cluster (relies on CLUSTERS_LEN from the CV cells).
INTERSECTION_GS = cluster_intersection(GENE_SETS, GENE_CLUSTERS)
INTERSECTION_GS = INTERSECTION_GS.div(CLUSTERS_LEN, axis=0) * 100
# Pad every cluster with 'other' so each row sums to MAX, the largest
# per-cluster total, rather than to 100; MAX is reported in the title.
MAX = INTERSECTION_GS.sum(axis=1).max()
INTERSECTION_GS['other'] = MAX - INTERSECTION_GS.sum(axis=1)

export_legend(GS_COL_DICT, f'{DIR_FIG}/GENE_SETS_{LABELS_COL}_legend.pdf')
plot_treemap_clusters(INTERSECTION_GS, GS_COL_DICT, TITLE=f'GeneSets overlap (% of cluster)',
                        RATIO=1.5, MAX_AREA=MAX )
plt.savefig(f'{DIR_FIG}/Intersection_GENE_SETS_perc_cluster_TREEMAP.pdf', format="pdf", bbox_inches="tight")


### Bivalent/Active Gonzalez

In [None]:
# Gonzalez bivalent/active gene lists: each column is divided by its total, so
# values are the % of a list's clustered genes found in each cluster.
# NOTE(review): GONZALEZ and GONZALEZ_COL_DICT are not defined in this notebook — presumably from CustomObjects.
INTERSECTION_BIV = cluster_intersection(GONZALEZ, GENE_CLUSTERS)
COL_SUM = INTERSECTION_BIV.sum()
INTERSECTION_BIV = INTERSECTION_BIV.div(COL_SUM, axis=1) * 100
#
export_legend(GONZALEZ_COL_DICT, f'{DIR_FIG}/Gonzalez_legend.pdf')
plot_barplots_clusters(INTERSECTION_BIV, GONZALEZ_COL_DICT, TITLE=f'Gonzalez overlap (% of list)',
                        Y_LAB='% of list', ADD_LABELS=False)
plt.savefig(f'{DIR_FIG}/Intersection_Gonzalez_perc_list.pdf', format="pdf", bbox_inches="tight")


#### % of cluster

In [None]:
# Gonzalez lists as % of each cluster (relies on CLUSTERS_LEN from the CV cells).
INTERSECTION_BIV = cluster_intersection(GONZALEZ, GENE_CLUSTERS)
INTERSECTION_BIV = INTERSECTION_BIV.div(CLUSTERS_LEN, axis=0) * 100
# Pad with 'other' so every row sums to MAX, the largest per-cluster total.
MAX = INTERSECTION_BIV.sum(axis=1).max()
INTERSECTION_BIV['other'] = MAX - INTERSECTION_BIV.sum(axis=1)
plot_treemap_clusters(INTERSECTION_BIV, GONZALEZ_COL_DICT, TITLE=f'{LABELS_COL} Gonzalez overlap (% of cluster)',
                        RATIO=1.5, MAX_AREA=MAX )
plt.savefig(f'{DIR_FIG}/Intersection_Gonzalez_perc_cluster_TREEMAP.pdf', format="pdf", bbox_inches="tight")

# Individual cluster analysis

In [None]:
# Per-gene CODE table and feature matrix, each tagged with its GMM cluster id.
CODE = pd.read_csv(f'../data/{ID}_analysis/CODE.csv',index_col='GENE')
#
LOG = pd.read_csv(f'../data/matrices/ALL/ALL_X_FC.csv').set_index('GENE')
# Feature columns whose name does not contain 'FC' (selected before the label column is added).
X_FEATURES = LOG.filter(regex='^(?!.*FC).*$').columns

# Invert GENE_CLUSTERS into a gene -> cluster-id lookup.
gene_to_cluster = {}
for cluster_id, cluster_info in GENE_CLUSTERS.items():
    for gene in cluster_info['gene_list']:
        gene_to_cluster[gene] = cluster_id

# Map each table's OWN index to cluster ids. CODE previously received labels
# computed from LOG's index, which is a positional assignment and mislabels
# rows unless both files list the same genes in the same order.
# astype(int) raises if a gene is missing from GENE_CLUSTERS (NaN label).
LOG[GMM_LABELS] = LOG.index.map(gene_to_cluster).astype(int)
CODE[GMM_LABELS] = CODE.index.map(gene_to_cluster).astype(int)

In [None]:
def plot_boxplots_clusters(df, CLUSTER_COL, FEATURE_PREFIXES, HM_COL_DICT=None, X_LINE=0, TITLE=None, X_LAB='Z-score'):
    """Draw, per cluster, horizontal violin + box plots of the selected features.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per gene; contains ``CLUSTER_COL`` and the ``FEATURE_PREFIXES`` columns.
    CLUSTER_COL : str
        Column holding each row's cluster label; one subplot per distinct label
        (in sorted order), titled ``C<label> (<n rows>)``.
    FEATURE_PREFIXES : list-like of str
        Feature columns to plot, one violin/box per feature.
    HM_COL_DICT : dict, optional
        Substring -> colour. Each feature gets the colour of the first key
        contained in its name (None if no key matches).
    X_LINE : float
        Position of a thin vertical reference line.
    TITLE : str, optional
        Figure super-title.
    X_LAB : str
        X-axis label, shown under the bottom plot of each column.

    The grid has 10 rows and as many columns as needed; all subplots share the
    global min/max of the features as x-limits. The figure is left open.
    """
    unique_labels = df[CLUSTER_COL].sort_values().unique()
    num_labels = len(unique_labels)
    VMIN = df[FEATURE_PREFIXES].min().min()
    VMAX = df[FEATURE_PREFIXES].max().max()

    MAPPED_COL = None
    if HM_COL_DICT is not None:
        # First HM_COL_DICT key that occurs in the feature name gives its colour.
        MAPPED_COL = [next((v for k, v in HM_COL_DICT.items() if k in feature), None)
                      for feature in FEATURE_PREFIXES]

    num_rows = 10
    num_cols = max(1, (num_labels + num_rows - 1) // num_rows)

    # squeeze=False keeps ``axes`` 2-D even when num_cols == 1.
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(2 * num_cols, 4 + num_rows * len(FEATURE_PREFIXES) * 0.20),
                             squeeze=False)

    plt.suptitle(TITLE, fontsize=20, y=1.02)
    for idx, label in enumerate(unique_labels):
        row_index, col_index = divmod(idx, num_cols)
        ax = axes[row_index, col_index]

        subset_df = df[df[CLUSTER_COL] == label]
        num_genes = subset_df.shape[0]
        # Long format for seaborn: one row per (gene, feature) value.
        melted_df = subset_df.melt(id_vars=CLUSTER_COL, value_vars=FEATURE_PREFIXES, var_name='HM')

        sns.violinplot(ax=ax, y='HM', x='value', data=melted_df, hue='HM', palette=MAPPED_COL, saturation=0.8,
                       width=0.7, inner=None, linewidth=0, cut=1)
        plt.setp(ax.collections, alpha=.8)

        sns.boxplot(ax=ax, y='HM', x='value', data=melted_df, hue='HM', palette=MAPPED_COL,
                    width=0.3, linewidth=1, showcaps=1,
                    medianprops={"color": "w", "linewidth": 1},
                    boxprops=dict(facecolor="black", alpha=0.4),
                    whiskerprops=dict(color="black", alpha=0.5),
                    flierprops=dict(marker='.', markerfacecolor='grey', markersize=1, alpha=0.5))

        ax.set_title(f'C{label} ({num_genes})', size=15)
        ax.set_ylabel('')
        ax.axvline(x=X_LINE, color='black', linestyle='-', linewidth=0.5)

        # Same x-limits for every subplot.
        ax.set_xlim(VMIN, VMAX)
        # X label only on the bottom plot of each column (also correct when
        # the last row is only partly filled).
        if idx + num_cols >= num_labels:
            ax.set_xlabel(X_LAB, fontsize=12)
        else:
            ax.set_xlabel('')
        # Feature names only on the first column.
        if col_index != 0:
            ax.set_yticklabels([])

        # tick_params also applies to ticks regenerated at draw time.
        ax.tick_params(axis='x', labelsize=12)

    # Hide unused subplots (safe even when there are no labels).
    for extra_ax in axes.flat[num_labels:]:
        fig.delaxes(extra_ax)

    plt.tight_layout()


In [None]:
# Per-cluster distributions of the CODE columns listed in PREFIXES (x axis labelled Z-score).
# NOTE(review): PREFIXES and HM_COL_DICT are not defined in this notebook — presumably from CustomObjects.
plot_boxplots_clusters(CODE, GMM_LABELS, PREFIXES, HM_COL_DICT=HM_COL_DICT ,X_LINE=0, TITLE=GMM_LABELS, X_LAB='Z-score')
plt.savefig(f'{DIR_FIG}/Feature_dist_avg.png', format="png", bbox_inches="tight")

In [None]:
# Same per-cluster plots for LOG's non-FC feature columns (x axis labelled log10(x)); displayed only, not saved.
plot_boxplots_clusters(LOG, GMM_LABELS, X_FEATURES, HM_COL_DICT=HM_COL_DICT ,X_LINE=0, TITLE=GMM_LABELS, X_LAB='log10(x)')

In [None]:
# Load the RNA (FPKM) and ChIP matrices, both indexed by the 'GENE' column;
# the bare name on the last line displays CHIP_levels in the notebook.
RNA = pd.read_csv(f'../data/matrices/RNA_FPKM_TSS_2500_FILT.csv').set_index('GENE')
CHIP_levels = pd.read_csv(f'../data/matrices/ChIP_TSS2500_RAW_FILT.csv').set_index('GENE')
CHIP_levels

In [None]:
def gene_trend(MAIN,GENE_LIST,CT_LIST,CT_COL_DICT,SAVE_PREFIX,Y_LAB='FPKMs'):
    """Plot each gene's values across cell types on a square grid and save a PDF.

    Parameters
    ----------
    MAIN : pandas.DataFrame
        Genes as the index. Each column name must contain one of the
        ``CT_LIST`` tokens, which labels the column's cell type. The first
        digit in the name is recorded as the replicate (REP).
    GENE_LIST : list of str
        Row labels of ``MAIN`` to plot, one subplot each.
    CT_LIST : list of str
        Cell-type tokens, joined into a regex alternation.
    CT_COL_DICT : dict
        Cell type -> colour for the strip plot points.
    SAVE_PREFIX : str
        Output path below ``../figures/gene_trend/`` without the ``.pdf`` suffix.
    Y_LAB : str
        Name of the value column and y-axis label.

    The figure is saved with ``plt.savefig`` and left open.
    """
    num_genes = len(GENE_LIST)
    grid_size = math.ceil(math.sqrt(num_genes))
    # Loop-invariant regex capturing the cell-type token of each column name.
    CT_PATTERN = f"({'|'.join(CT_LIST)})"

    plt.figure(figsize=(grid_size*3, grid_size*3))

    for i, GENE_NAME in enumerate(GENE_LIST):
        Series = MAIN.loc[GENE_NAME]
        CT = Series.index.str.extract(CT_PATTERN)[0]
        # Raw string: '\d' in a plain (f-)string is an invalid escape sequence.
        REP = Series.index.str.extract(r'(\d)')[0]
        DF = pd.DataFrame({Y_LAB: Series.values, 'CT': CT, 'REP': REP})

        plt.subplot(grid_size, grid_size, i + 1)
        plt.title(GENE_NAME)
        # One point per sample, coloured by cell type...
        sns.stripplot(data=DF, x='CT', y=Y_LAB, hue='CT', palette=CT_COL_DICT,
                      s=12, alpha=1, legend=False, linewidth=0)
        # ...joined by a dashed black line through seaborn's default (mean) estimate per cell type.
        sns.lineplot(data=DF, x='CT', y=Y_LAB, err_style=None,
                     color='black', linewidth=1, dashes=(2, 2))
        sns.despine(left=1, bottom=0, top=1)
        plt.xlabel('')
        plt.xticks([])
        y_max = DF[Y_LAB].max()
        plt.yticks(np.ceil([0, y_max]))
        plt.ylim(np.ceil([0, y_max * 1.1]))
        # Y-axis label only on the first column of each row.
        if i % grid_size != 0:
            plt.ylabel('')

    # One layout pass once every subplot exists.
    plt.tight_layout()
    plt.savefig(f'../figures/gene_trend/{SAVE_PREFIX}.pdf', format="pdf", bbox_inches="tight")
    

### For each cluster, plot per-gene RNA trends (up to 16 randomly selected genes)

In [None]:
#take ~20 min to run
import random
random.seed(42)

MAX_LEN = 16
for c, value in GENE_CLUSTERS.items():
    NAME= f'C{c}'
    GENE_LIST = value['gene_list']
    if len(GENE_LIST) > MAX_LEN:
        print(NAME, len(GENE_LIST))
        GENE_LIST = random.sample(GENE_LIST, MAX_LEN)
        
    ! mkdir -p '../figures/gene_trend/clusters/{NAME}'
    #HM_LIST = ['H3K27ac', 'H3K27me3', 'H3K4me3','WCE']
    #for HM in HM_LIST:
    #    gene_trend(CHIP_levels.filter(regex=HM)
    #                ,GENE_LIST= GENE_LIST,CT_LIST=CT_LIST, CT_COL_DICT=CT_COL_DICT, 
    #                SAVE_PREFIX=f'clusters/{NAME}/{HM}', Y_LAB=HM)
    gene_trend(RNA,GENE_LIST= GENE_LIST,CT_LIST=CT_LIST, CT_COL_DICT=CT_COL_DICT, 
                    SAVE_PREFIX=f'clusters/{NAME}/RNA')
    

In [None]:
! mkdir -p '../figures/gene_trend/clusters'
# Sanity check of gene_trend on a fixed list of 16 genes;
# writes ../figures/gene_trend/clusters/test.pdf.
NAME= 'test'
GENE_LIST= ['Atp6v0e2', 'Axl', 'Card19', 'Cebpb', 'Cnn1', 'Copz2', 'Coq8a',
        'Creb3l1', 'Creld1', 'Ctsf', 'Ddr1', 'Dgat2', 'Dpysl3', 'Dusp3', 'Eml2',
        'Fndc10']


gene_trend(RNA,GENE_LIST= GENE_LIST,CT_LIST=CT_LIST, CT_COL_DICT=CT_COL_DICT, SAVE_PREFIX=f'clusters/{NAME}')

In [None]:
# Display the RNA matrix for inspection.
RNA