### Set parameters and names

In [None]:
from CustomObjects import *
%matplotlib inline
%config InlineBackend.figure_format = 'retina'


In [None]:
import os

# Input representation that, together with CODINGS_SIZE, names the analysis run (ID).
INPUT_GENES = 'ALL'
INPUT_FEATURES = 'X_FC'
INPUT_NORM = 'z'
CODINGS_SIZE = 6  # appears as the "{N}D" prefix of ID (presumably the VAE latent size)

ID = f'{CODINGS_SIZE}D_{INPUT_GENES}_{INPUT_FEATURES}_{INPUT_NORM}'
METHOD = 'VAE'
k = 80  # cluster count: the plotting cells below iterate over clusters 0..k-1
LABELS_COL = f'GMM_{METHOD}_{k}'
# Alias rather than a second copy of the same f-string, so the two names
# can never disagree.
GMM_LABELS = LABELS_COL

DIR_DATA = f'../data/{ID}_analysis/{LABELS_COL}'
DIR_FIG = f'../figures/{ID}_analysis/{METHOD}/Clusters_k{k}/TSSplots'
# Pure-Python equivalent of `! mkdir -p {DIR_FIG}`: works outside IPython and
# is not affected by shell quoting of the path.
os.makedirs(DIR_FIG, exist_ok=True)


In [None]:
# ChIP metadata table; later cells filter it on the REP, SAMPLE_ID,
# TARGET and CELL_TYPE columns.
METADATA = pd.read_csv(
    filepath_or_buffer='../../01_Mapping/data/ChIP_NARROW.csv',
)

In [None]:
# Keep every REP == 1 library plus one explicitly named H3K27me3 ESC sample,
# which is kept whatever its REP value.
# NOTE(review): later cells pick a sample with `.values[0]`; if REP1 also holds
# a REP == 1 H3K27me3 ESC library, whichever row comes first wins — confirm
# that is intended.
_is_rep1 = METADATA['REP'].eq(1)
_is_whitelisted = METADATA['SAMPLE_ID'].eq('H3K27me3_ESC_2_Wamstad_2013_SE')
REP1 = METADATA.loc[_is_rep1 | _is_whitelisted].reset_index(drop=True)

# Export the selected SAMPLE_IDs one per line (no header, no index), then echo them.
IDS = REP1['SAMPLE_ID']
IDS.to_csv(f'{DIR_DATA}/BAM_IDs_Rep1.list', header=False, index=False)
IDS

In [None]:
def smooth_signal(y_values, window_size=200):
    """Return a centred rolling mean of *y_values*.

    Each output point is the mean of up to ``window_size`` samples centred on
    it. Because ``min_periods=1``, points near the edges are averaged over the
    samples that fall inside the window instead of being left as NaN.

    Parameters
    ----------
    y_values : pandas.Series (anything exposing ``.rolling``)
        Signal to smooth.
    window_size : int, default 200
        Width of the rolling window, in samples.
    """
    roller = y_values.rolling(window_size, min_periods=1, center=True)
    return roller.mean()


In [None]:
# Styling shared with the extended-plot cells below. Histone-mark tracks are
# drawn solid at width 2.5; the WCE control track (presumably the whole-cell
# extract input) is dashed and thinner.
TARGETs = ['H3K4me3', 'H3K27ac', 'H3K27me3']
dashes_dict = {'H3K4me3': '', 'H3K27ac': '', 'H3K27me3': '', 'WCE': (3, 1.5)}
style_dict = {'H3K4me3': 2.5, 'H3K27ac': 2.5, 'H3K27me3': 2.5, 'WCE': 1.5}
CT_COL_DICT = {'ESC': '#394b7d', 'MES': '#723386', 'CP': '#c33a56', 'CM': '#f19a7a'}

# The loop used to end in a bare `break`, so only cluster 0 was rendered.
# That default is kept; set to False to render and save all k clusters.
PREVIEW_FIRST_CLUSTER_ONLY = True

for c in range(k):
    plt.figure(figsize=(3, 8))
    plt.suptitle(f'Cluster {c}', fontsize=12)

    # One stacked panel per histone mark.
    for i, TARGET in enumerate(TARGETs, start=1):
        plt.subplot(len(TARGETs), 1, i)

        # Smoothed mark profile followed by its WCE profile, for every cell type.
        # Collect the frames and concatenate once. Appending to an empty
        # DataFrame in the loop re-copied DF each time and is deprecated in
        # pandas >= 2.1.
        frames = []
        for CT in CT_LIST:  # CT_LIST assumed to come from `from CustomObjects import *`
            for track in (TARGET, 'WCE'):
                sample_id = REP1.loc[(REP1['TARGET'] == track) & (REP1['CELL_TYPE'] == CT),
                                     'SAMPLE_ID'].values[0]
                path = (f"{DIR_DATA}/TSSplots/C{c}_{sample_id}_TSSplot_2500/"
                        f"TSSprofile_C{c}_{sample_id}_2500.txt")
                profile = pd.read_csv(path, header=None, sep='\t', names=['X', 'Y'])
                profile['Y'] = smooth_signal(profile['Y'])
                profile['TARGET'] = track
                profile['CT'] = CT
                frames.append(profile)
        DF = pd.concat(frames, ignore_index=True)

        sns.lineplot(data=DF, x='X', y='Y', hue='CT', palette=CT_COL_DICT, alpha=1,
                     legend=False, linewidth=2.5,
                     style='TARGET', dashes=dashes_dict,
                     size='TARGET', sizes=style_dict)

        # Panel title: the mark name in white on a box filled with the mark's
        # colour from HM_COL_DICT (assumed to come from CustomObjects).
        plt.text(0.5, 1.1, TARGET, horizontalalignment='center', verticalalignment='center',
                 transform=plt.gca().transAxes, color='white', fontsize=16,
                 bbox=dict(facecolor=HM_COL_DICT[TARGET], edgecolor=HM_COL_DICT[TARGET],
                           linewidth=1, pad=2, alpha=0.9))

        # Two y ticks only: 0 and this panel's maximum (marks and WCE together).
        max_y = DF['Y'].max()
        plt.yticks(ticks=[0, max_y], labels=['0', round(max_y, 1)], fontsize=14)
        # Short silver tick at x = 0 marking the TSS.
        plt.axvline(x=0, color='silver', linestyle='-', linewidth=1, ymin=0, ymax=0.05)

        if i == len(TARGETs):
            # Only the bottom panel carries the x axis.
            plt.xticks(ticks=[-2500, 0, 2500], labels=[-2500, 0, 2500], fontsize=14)
            plt.xlabel('TSS distance (bp)', fontsize=14)
        else:
            plt.xticks([])
            plt.xlabel('')
        plt.ylabel('')
        sns.despine()
        plt.grid(False)

    # Lay out once, after every panel exists (it used to run after each panel).
    plt.tight_layout()
    plt.savefig(f'{DIR_FIG}/C{c}.pdf', format="pdf", bbox_inches="tight")
    # Display, then release the figure, so a full k-cluster run does not keep
    # k figures open at once.
    plt.show()
    plt.close()

    if PREVIEW_FIRST_CLUSTER_ONLY:
        break



### Extended view: one subplot per cell type (CT)

In [None]:
# For each (mark, cell type), find the largest smoothed signal over all k
# clusters. The maximum covers both the mark profile and its WCE profile,
# matching what is drawn in each panel. The extended plots use it as the
# y tick, so a panel's scale is comparable across clusters.


def _tss_profile_max(cluster, sample_id):
    """Return the max of the smoothed Y column of one TSS-profile file."""
    path = (f"{DIR_DATA}/TSSplots/C{cluster}_{sample_id}_TSSplot_2500/"
            f"TSSprofile_C{cluster}_{sample_id}_2500.txt")
    profile = pd.read_csv(path, header=None, sep='\t', names=['X', 'Y'])
    return smooth_signal(profile['Y']).max()


# Resolve each (track, cell type) to its SAMPLE_ID once, not once per cluster.
_sample_ids = {
    (track, CT): REP1.loc[(REP1['TARGET'] == track) & (REP1['CELL_TYPE'] == CT),
                          'SAMPLE_ID'].values[0]
    for track in [*TARGETs, 'WCE']
    for CT in CT_LIST
}

max_y_dict = {TARGET: {CT: 0 for CT in CT_LIST} for TARGET in TARGETs}
for c in range(k):
    for CT in CT_LIST:
        # Every mark of a cell type shares the same WCE profile, so read it once.
        wce_max = _tss_profile_max(c, _sample_ids[('WCE', CT)])
        for TARGET in TARGETs:
            target_max = _tss_profile_max(c, _sample_ids[(TARGET, CT)])
            # pandas max skips NaN, like the former DF['Y'].max() over both tracks.
            current_max_y = pd.Series([target_max, wce_max]).max()
            if current_max_y > max_y_dict[TARGET][CT]:
                max_y_dict[TARGET][CT] = current_max_y
            

In [None]:
# Extended layout: one row per mark and one column per cell type, plus a
# leading column that holds only the mark's name. y ticks use the
# cross-cluster maxima in max_y_dict.
n_rows, n_cols = len(TARGETs), len(CT_LIST) + 1

for c in range(k):
    plt.figure(figsize=(10, 6))
    plt.suptitle(f'Cluster {c}', fontsize=16)

    for target_idx, TARGET in enumerate(TARGETs):
        row_start = target_idx * n_cols + 1  # subplot index of this row's label cell
        is_last_row = target_idx == n_rows - 1

        # Row label cell: the mark name on a box in the mark's colour, axes hidden.
        plt.subplot(n_rows, n_cols, row_start)
        plt.text(0.7, 0.5, TARGET, horizontalalignment='center', verticalalignment='center',
                 color='white', fontsize=18,
                 bbox=dict(facecolor=HM_COL_DICT[TARGET], edgecolor=HM_COL_DICT[TARGET],
                           linewidth=1, pad=2, alpha=0.9),
                 transform=plt.gca().transAxes)
        plt.axis('off')

        for ct_idx, CT in enumerate(CT_LIST, start=1):
            plt.subplot(n_rows, n_cols, row_start + ct_idx)

            # Smoothed mark profile followed by its WCE profile. Collect them and
            # concatenate once instead of appending to an empty DataFrame.
            frames = []
            for track in (TARGET, 'WCE'):
                sample_id = REP1.loc[(REP1['TARGET'] == track) & (REP1['CELL_TYPE'] == CT),
                                     'SAMPLE_ID'].values[0]
                path = (f"{DIR_DATA}/TSSplots/C{c}_{sample_id}_TSSplot_2500/"
                        f"TSSprofile_C{c}_{sample_id}_2500.txt")
                profile = pd.read_csv(path, header=None, sep='\t', names=['X', 'Y'])
                profile['Y'] = smooth_signal(profile['Y'])
                profile['TARGET'] = track
                profile['CT'] = CT
                frames.append(profile)
            DF = pd.concat(frames, ignore_index=True)

            sns.lineplot(data=DF, x='X', y='Y', hue='CT', palette=CT_COL_DICT, alpha=1,
                         legend=False, linewidth=2.5, style='TARGET', dashes=dashes_dict,
                         size='TARGET', sizes=style_dict)

            # Single y tick at the cross-cluster maximum for this (mark, CT).
            # Setting a tick outside the data range extends the view to include it.
            max_y = max_y_dict[TARGET][CT]
            plt.yticks(ticks=[max_y], labels=[round(max_y, 1)], fontsize=14)
            # Short silver tick at x = 0 marking the TSS.
            plt.axvline(x=0, color='silver', linestyle='-', linewidth=1, ymin=0, ymax=0.05)

            # Cell-type header boxes go on the first row only.
            if target_idx == 0:
                plt.text(0.5, 1.1, CT, horizontalalignment='center', verticalalignment='center',
                         transform=plt.gca().transAxes, color='white', fontsize=18,
                         bbox=dict(facecolor=CT_COL_DICT[CT], edgecolor=CT_COL_DICT[CT],
                                   linewidth=1, pad=2, alpha=0.9))

            plt.xlabel('')
            if is_last_row:
                plt.xticks(ticks=[-2500, 0, 2500], labels=['-2500', 'TSS', '2500'], fontsize=14)
            else:
                plt.xticks([])
            plt.ylabel('')
            sns.despine()
            plt.grid(False)

    # Lay out once, after every panel exists (it used to run after each panel).
    plt.tight_layout()
    plt.savefig(f'{DIR_FIG}/C{c}_ext.pdf', format="pdf", bbox_inches="tight")
    # Display, then release the figure. Otherwise all k figures would stay
    # open until the cell finishes.
    plt.show()
    plt.close()
    