In [None]:
# This notebook generates heatmaps; it is set up for the test example "test".

In [None]:
# Conversion factor used for the figure sizes: centimetres * cm = inches.
cm = 1 / 2.54

# rcParams handed to plt.rcParams.update(...) by the plotting functions.
pgf_with_custom_preamble = dict([
    ("font.family", "serif"),   # use serif/main font for text elements
    ("font.size", 10),
    ("text.usetex", True),      # use inline math for ticks
    ("pgf.rcfonts", False),     # don't setup fonts from rc parameters
])

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.use("pgf")
import pandas as pd
import numpy as np
import seaborn as sns
import json
import matplotlib.ticker as plticker
from matplotlib.legend import _get_legend_handles_labels

In [None]:
#INPUT VALUES

# Pieces of the input file names / directories
BASENAME = "test"
START = "from_2D/test"
EXTEND = "expand_test"

# Loop bounds used while loading: expansion rounds, simulation rounds, clusters
ROUNDS = 9
SIMROUND = 3
CLUSTER = 2

# Stored in the 'nstep' column of every loaded table and appended to the plot title
nstep = 5000

# Constraint length per expansion round (looked up as INTLEN[ROUND])
INTLEN = [length for length in range(3, 20, 2)]


In [None]:
#LOAD DATA
# Read one tab-separated table per (expansion round, cluster, simulation round)
# and tag every row with the run it was read from.

fclust = list()
df_sum = list()
errorload = list()

for ROUND in range(ROUNDS):
    for CL in range(CLUSTER):
        for SR in range(SIMROUND):
            FILE = (f"{START}/cluster{CL}/{ROUND}/{SR+1}/"
                    f"{BASENAME}c{CL}_{EXTEND}_{ROUND}_{SR+1}.csv")
            try:
                # Lines 1 and 2 (0-based) of the file are skipped; line 0 is the header.
                df = pd.read_csv(FILE, sep="\t", skiprows=[1, 2])
            except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError):
                # Only missing/unreadable tables are tolerated; anything else
                # (e.g. a too short INTLEN or an interrupt) must surface.
                errorload.append(FILE)
                continue

            df['cluster'] = CL
            df['nstep'] = nstep
            df['constraint'] = int(INTLEN[ROUND])
            df['expansionstep'] = ROUND
            df['ROUND'] = SR+1
            fclust.append(df)

if errorload:
    print("Probably there were no further expansion possible - please check:")
    print(*errorload, sep="\n")

if not fclust:
    # pd.concat([]) would fail with a message that hides the real cause.
    raise ValueError("No input table could be loaded - "
                     "check START, BASENAME, EXTEND and the directory layout.")

df_c = pd.concat(fclust)
df_sum.append(df_c)

df_sumup = pd.concat(df_sum)

# Difference between the energy including the restraint score and the plain energy
df_sumup["energy_diff"] = (df_sumup["energy_values_plus_restraint_score"]
                           - df_sumup["energy_value"])


In [None]:
#PREPARE DATA FOR PLOTTING

percent_list = list()
heatmap1prep = list()
heatmap2prep = list()
heatmap3prep = list()

# Column names of the per-group statistics, in the order produced by the
# aggregation in _group_stats (the group key itself is prepended there).
_STAT_COLUMNS = ['nr',
                 'energy_min', 'energy_median', 'energy_mean',
                 'dif_min', 'dif_max', 'dif_median', 'dif_mean',
                 'intraA_mean', 'intraB_mean']


def round_mean(x):
    """Return the mean of *x* rounded to zero decimals."""
    return round(x.mean(), 0)


def _group_stats(round_df, group_col, key_name):
    """Aggregate the rows of one expansion round per value of *group_col*.

    :param round_df: rows of df_sumup belonging to one constraint length
    :param group_col: column to group by
    :param key_name: name given to the group-key column in the result
    :return: flat dataframe with the columns [key_name] + _STAT_COLUMNS
    """
    stats = round_df.groupby([group_col]).agg({group_col: ['count'],
                                               'energy_value': ['min', 'median', 'mean'],
                                               'energy_diff': ['min', 'max', 'median', 'mean'],
                                               'intra_chainA': [round_mean],
                                               'intra_chainB': [round_mean],
                                               }).reset_index()
    stats.columns = [key_name] + _STAT_COLUMNS
    return stats


def _add_round_info(stats, key_name, total, constraint):
    """Add percentages and the constraint length, sort by *key_name*, cast energies to float.

    :param stats: result of _group_stats
    :param key_name: name of the group-key column, used for sorting
    :param total: number of structures counted in the first expansion round
    :param constraint: constraint length of the current expansion round
    :return: the completed (new) dataframe
    """
    stats['percent_col'] = (stats.nr/stats.nr.sum())*100
    stats['percent_tot'] = (stats.nr.sum()/total)*100
    stats['sum_per_expand'] = stats.nr.sum()
    stats['constraint'] = constraint
    # sort_values returns a new frame - the result has to be kept
    stats = stats.sort_values(by=[key_name])
    return stats.astype({'energy_min': float,
                         'energy_median': float,
                         'energy_mean': float})


for ROUND in range(ROUNDS):
    round_df = df_sumup[(df_sumup["constraint"] == INTLEN[ROUND])]

    #sorted by the perfect interaction helix length
    values1 = _group_stats(round_df, "len_interaction", 'continuous_interaction')

    if ROUND == 0:
        # reference for 'percent_tot': everything counted in the first round
        total = values1.nr.sum()

    values1 = _add_round_info(values1, 'continuous_interaction', total, INTLEN[ROUND])
    heatmap1prep.append(values1)

    #sorted by the total interaction helix length (e.g. incl. bulges)
    values2 = _group_stats(round_df, "interaction_countbp", 'interaction')
    values2 = _add_round_info(values2, 'interaction', total, INTLEN[ROUND])
    heatmap2prep.append(values2)

    #evaluate intramolecular kissing hairpin structure
    values3 = round_df.groupby(["intra_chainA"]).agg({'intra_chainA': ['count'],
                                                      'intra_chainB': ['count'],
                                                      }
                                                     ).reset_index()

    values3.columns = ['intramolecular', 'nr', 'nrB']
    heatmap3prep.append(values3)

    percent_list.append(round(((values2.nr.sum()/total)*100)))

dfheatmap1 = pd.concat(heatmap1prep)
dfheatmap2 = pd.concat(heatmap2prep)
dfheatmap3 = pd.concat(heatmap3prep)

In [None]:
#SAVE DATA IF WANTED

# One tab-separated table per prepared heatmap input
for suffix, table in (('_perfecthelix-interaction.csv', dfheatmap1),
                      ('_interaction.csv', dfheatmap2),
                      ('_intramolecular.csv', dfheatmap3)):
    table.to_csv(START+'/'+BASENAME+suffix, encoding='utf-8', sep='\t')

# The rounded percentages, one value per line
with open(START+'/'+BASENAME+'.per', "w") as FILE:
    FILE.writelines(str(s) + "\n" for s in percent_list)
        

In [None]:
#PLOTTING CONFIGURATIONS

#Title
title = BASENAME+str(nstep)

#Plotting
count_type = 'percent_col' #choose between 'nr' or 'percent_col' or 'percent_tot'
energy_type = 'energy_min' #choose between:'energy_median', 'energy_mean', 'energy_min'
constraint_type = 'dif_median' #choose between:'dif_median', 'dif_mean', 'dif_min'

#Label descriptions
# Raw strings: '\%' is not a valid Python escape sequence; the backslash is
# meant to be part of the label text (escaped percent sign).
count_label = r'structures [\%]'
energy_label = 'min energy'
dif_label = 'penalty'
tot_label = r'start [\%]'

#Label min/max
# The limits follow the columns selected above so that the colour bars
# cover the values that are actually drawn.
min_e, max_e = dfheatmap1[energy_type].min(), dfheatmap1[energy_type].max()
min_p, max_p = dfheatmap1[constraint_type].min(), dfheatmap1[constraint_type].max()
max_ia = max(dfheatmap1["intraA_mean"].max(),
             dfheatmap1["intraB_mean"].max()
            )

#Colouring
# Values below vmin are drawn white for the "light" colormaps.
blue = sns.light_palette("darkblue", as_cmap=True)
blue.set_under(color="w", alpha=None)

green = sns.light_palette("green", as_cmap=True)
green.set_under(color="w", alpha=None)

olive = sns.light_palette("olivedrab", as_cmap=True)
olive.set_under(color="w", alpha=None)

# Reversed palette: here values above vmax are drawn white.
red = sns.light_palette("firebrick", reverse=True, as_cmap=True)
red.set_over(color="w", alpha=None)

cmap_reversed = mpl.colormaps.get_cmap('pink_r')

purple = sns.light_palette("indigo", as_cmap=True)
purple.set_under(color="w", alpha=None)

rocket = sns.color_palette("rocket", as_cmap=True)
rocket.set_under(color="w", alpha=None)

# Palette of 8 colours used for the intramolecular base pair panels
intra = sns.cubehelix_palette(8, start=2, rot=-0.3, dark=0.1, light=0.9)



In [None]:
### Various functions for plotting

In [None]:
def label(df,length,index):
    """
    Return the x- and y-axis labels for the heatmaps.

    :param df: dataframe for which the labels will be generated
    :param length: sum of the expansion rounds that were possible
                   corresponds to the length of the percent_list
    :param index: choose between 'interaction' or 'continuous_interaction'
    :return: tuple (labelx, labelxc5, labely, xlinec5)
             labelx   - sorted unique values of the 'constraint' column
             labelxc5 - per shown constraint the triple 0, constraint, 0
             labely   - sorted unique values of the column *index*
             xlinec5  - multiples of 5 (5, 10, ...), one per shown constraint
                        after the first; used as separator positions
    """

    # np.unique returns the values sorted, so the last entry is the maximum.
    labelx = np.unique(df['constraint'].to_numpy())

    # At most `length` constraints are labelled and never more than exist;
    # an empty frame therefore yields empty label lists instead of an error.
    shown = min(length, len(labelx))

    labelxc5 = list()
    xlinec5 = list()

    for pos in range(shown):
        if pos != 0:
            xlinec5.append(pos * 5)
        labelxc5.extend([0, labelx[pos], 0])

    # already sorted by np.unique
    labely = np.unique(df[index].to_numpy())

    return labelx, labelxc5, labely, xlinec5

In [None]:
def _pivot_by_constraint(df, index, values, **kwargs):
    """Pivot *values* with *index* as rows and 'constraint' as columns (summed)."""
    # The string 'sum' is what pandas recommends instead of passing np.sum.
    return pd.pivot_table(df,
                          index=index,
                          columns='constraint',
                          values=values,
                          aggfunc='sum',
                          **kwargs)


def _spread_to_slot(table, slot, width=5):
    """Widen *table* by *width*: column j goes to column j*width+slot, the rest is 0."""
    values = np.asarray(table)
    rows, cols = values.shape
    spread = np.zeros((rows, cols, width), dtype=values.dtype)
    spread[:, :, slot] = values
    return spread.reshape(rows, -1)


def heatmap(df,index):
    """
    Return the pivot tables for seaborn heatmap plots
    c5 -> every constraint column is widened to 5 columns, one slot per table:
          slot 0: nr, slot 1: intramolecular basepairs A, slot 2: intramolecular
          basepairs B, slot 3: energy, slot 4: penalty

    Reads the module-level settings count_type, energy_type and constraint_type
    to select the plotted columns.

    :param df: dataframe for which the pivot tables will be generated
    :param index: choose between 'interaction' or 'continuous_interaction'
    :return: tuple (heatmap_nr,
                    heatmapc5_nr, heatmapc5_e, heatmapc5_dif,
                    heatmapc5_intraA, heatmapc5_intraB,
                    maskc5)
             heatmap_nr is the plain pivot table, the heatmapc5_* entries are
             the widened arrays and maskc5 holds the slot number (column % 5)
             of every cell of the widened arrays.
    """

    heatmap_nr = _pivot_by_constraint(df, index, count_type, fill_value=0)
    heatmap_e = _pivot_by_constraint(df, index, energy_type, fill_value=0)
    heatmap_dif = _pivot_by_constraint(df, index, constraint_type, fill_value=0)

    # No fill value here: missing combinations stay NaN.
    heatmap_intraA = _pivot_by_constraint(df, index, 'intraA_mean')
    heatmap_intraB = _pivot_by_constraint(df, index, 'intraB_mean')

    heatmapc5_nr = _spread_to_slot(heatmap_nr, 0)
    heatmapc5_intraA = _spread_to_slot(heatmap_intraA, 1)
    heatmapc5_intraB = _spread_to_slot(heatmap_intraB, 2)
    heatmapc5_e = _spread_to_slot(heatmap_e, 3)
    heatmapc5_dif = _spread_to_slot(heatmap_dif, 4)

    # Slot number of every cell; also valid for a table without rows.
    maskc5 = np.tile(np.arange(heatmapc5_nr.shape[1]) % 5,
                     (heatmapc5_nr.shape[0], 1))

    return (heatmap_nr,
            heatmapc5_nr, heatmapc5_e, heatmapc5_dif, heatmapc5_intraA, heatmapc5_intraB,
            maskc5)


In [None]:
def plot(nr, index,
         percent_list,
         labelx, labely,
         title,
         sizel, sizeb,
         ax1set, ax2set,
         color,scale,
         text
         ):
    """
    Plot the distribution of the interaction extension success as a Seaborn-Heatmap
    and save it as pdf and pgf below START.

    x-axis: constraint interaction length (bps), labelled on the purple bar
    y-axis: interacting bps
    purple bar:  percentage of the n_cluster x n_run - structures
                 that are allowed to continue in the current extension step

    Reads the module-level names cm, pgf_with_custom_preamble, count_label,
    tot_label, purple and START.

    :param nr: heatmap with the respective percentage
    :param index: choose between 'interaction' or 'continuous_interaction'
    :param percent_list: list for the purple bar
    :param labelx: x-axis description
    :param labely: y-axis description
    :param title: title for the plot and name for the output pdf/pgf
    :param sizel, sizeb: size of the plotting page (figure height, width) in cm
    :param ax1set: dimensions for the heatmap
    :param ax2set: dimensions for the purple bar
    :param color: color for distribution of structures within the heatmap
    :param scale: scaling factor for the color bars
    :param text: position (x, y) of the purple bar description
    """

    # Apply the style first so it is in effect when the axes are created;
    # otherwise the first figure of a session is built with the previous settings.
    plt.rcParams.update(pgf_with_custom_preamble)

    fig = plt.figure(figsize=(sizeb*cm, sizel*cm))
    ax1 = fig.add_axes(ax1set)
    ax2 = fig.add_axes(ax2set)

    # Plot the table that was passed in (not a notebook-level variable).
    sns.heatmap(nr,
                ax = ax1,
                cmap=color,
                linewidth =1,
                cbar_kws={'label': count_label,"shrink":scale},
                square=True,
                cbar=True,
                vmin=0.0000001,   # zeros fall below vmin and get the 'under' colour
                vmax=100,
                xticklabels=False,
                yticklabels=labely)

    sns.heatmap([percent_list],
                ax = ax2,
                linewidth =1,
                cmap=purple,
                vmin=0.0000001,
                vmax=100,
                fmt="",
                square=False,
                cbar=False,
                annot=True,
                xticklabels=labelx,
                yticklabels=False)

    ax2.set_xlabel('constraint interaction bps')
    ax1.set(ylabel = 'interacting bps', title=str(title+' - '+index))

    # tot_label carries the escaped percent sign, like count_label.
    fig.text(text[0], text[1], tot_label)
    fig.savefig(START+'/'+title+'-'+index+'.pdf')
    fig.savefig(START+'/'+title+'-'+index+'.pgf')
    plt.close(fig)


In [None]:
def plot5(h5_nr, h5_e, h5_p, h5_intraA, h5_intraB,
          index,
          percent_list,
          min_e, max_e,
          min_p, max_p,
          max_ia,
          labelxc5, labely, xlinec5,
          title, mask,
          sizel, sizeb, ax1set, ax2set,
          color,
          scale,
          text
         ):

    """
    Plot the distribution of the interaction extension success as a Seaborn-Heatmap
    (incl. information regarding the SimRNA energy, used penalty,
    length of the intramolecular kissing hairpin stems)
    and save it as pdf and pgf below START.

    x-axis: constraint interaction length (bps), labelled on the purple bar;
            every constraint takes 5 columns
    y-axis: interacting bps
    purple bar:  percentage of the n_cluster x n_run - structures
                 that are allowed to continue in the current extension step

    Reads the module-level names cm, pgf_with_custom_preamble, count_label,
    dif_label, energy_label, tot_label, intra, purple, START and labelx.

    :param h5_nr: widened heatmap with the respective percentage (slot 0)
    :param h5_e: widened heatmap with the energy (slot 3)
    :param h5_p: widened heatmap with the penalty (slot 4)
    :param h5_intraA, h5_intraB: widened heatmaps with the intramolecular
                                 basepairs of chain A (slot 1) and B (slot 2)
    :param index: choose between 'interaction' or 'continuous_interaction'
    :param percent_list: list for the purple bar
    :param min_e, max_e: min and max of the SimRNA energy within the respective analyzed dataframe
    :param min_p, max_p: min and max of the penalty within the respective analyzed dataframe
    :param max_ia: max intramolecular kissing hairpin helix length within the respective analyzed dataframe
    :param labelxc5: x-axis description for the widened heatmap (currently not drawn)
    :param labely: y-axis description
    :param xlinec5: x positions of the vertical separator lines
    :param title: title for the plot and name for the output pdf/pgf
    :param mask: slot number (0-4) of every cell of the widened heatmaps
    :param sizel, sizeb: size of the plotting page (figure height, width) in cm
    :param ax1set: dimensions for the heatmap
    :param ax2set: dimensions for the purple bar
    :param color: color for distribution of structures within the heatmap
    :param scale: scaling factor for the color bars
    :param text: position (x, y) of the purple bar description
    """

    # Apply the style first so it is in effect when the axes are created;
    # otherwise the first figure of a session is built with the previous settings.
    plt.rcParams.update(pgf_with_custom_preamble)

    fig = plt.figure(figsize=(sizeb*cm, sizel*cm))
    ax1 = fig.add_axes(ax1set)
    ax2 = fig.add_axes(ax2set)

    # seaborn hides every cell whose mask entry is truthy (non-zero), so
    # `mask - k` leaves exactly the cells of slot k visible.
    sns.heatmap(h5_intraA,
                ax = ax1,
                cmap=intra,
                linewidth=2,
                cbar_kws={'label': 'intramolecular basepairs',"shrink":scale, "pad":-0.07},
                vmin=-0.5,
                vmax=max_ia,
                square=True,
                cbar=True,
                annot=True,
                mask = mask-1,
                xticklabels=False,
                yticklabels=labely)

    sns.heatmap(h5_nr,
                ax = ax1,
                cmap=color,
                linewidth=2,
                cbar_kws={'label': count_label,"shrink":scale, "pad":-0.05},
                fmt="",
                square=True,
                cbar=True,
                vmin=0.0000001,
                vmax=100,
                mask = mask,
                xticklabels=False,
                yticklabels=labely)

    # Same limits as chain A: both chains share the colour bar drawn above.
    sns.heatmap(h5_intraB,
                ax = ax1,
                cmap=intra,
                linewidth=2,
                vmin=-0.5,
                vmax=max_ia,
                square=True,
                cbar=False,
                annot=True,
                mask = mask-2,
                xticklabels=False,
                yticklabels=labely)

    sns.heatmap(h5_p,
                ax = ax1,
                cmap='pink_r',
                linewidth=2,
                cbar_kws={'label': dif_label,"shrink":scale, "pad":-0.02},
                fmt="",
                vmin=min_p,
                vmax=max_p,
                square=True,
                cbar=True,
                mask = mask-4,
                xticklabels=False,
                yticklabels=labely)

    sns.heatmap(h5_e,
                ax = ax1,
                cmap="afmhot",
                linewidth=2,
                cbar_kws={'label': energy_label,"shrink":scale},
                fmt="",
                vmin=min_e,
                vmax=max_e,
                square=True,
                cbar=True,
                mask = mask-3,
                xticklabels=False,
                yticklabels=labely)

    # NOTE(review): the bar is labelled with the notebook-level `labelx`
    # (one label per constraint); labelxc5 has three entries per constraint
    # and would not match the number of cells of the bar.
    sns.heatmap([percent_list],
                ax = ax2,
                cmap=purple,
                vmin=0.0000001,
                vmax=100,
                linewidth=1,   # the contradicting duplicate `linewidths=2` was dropped
                fmt="",
                square=False,
                cbar=False,
                annot=True,
                xticklabels=labelx,
                yticklabels=False)

    # Separators between the constraint groups - taken from the argument.
    for n in xlinec5:
        ax1.axvline(x=n, color='lightgrey')

    ax2.set_xlabel('constraint interaction bps')
    ax1.set(ylabel = 'interacting bps', title=title+' - '+index)

    # tot_label carries the escaped percent sign, like count_label.
    fig.text(text[0], text[1], tot_label)

    fig.savefig(START+'/'+title+'-'+index+'_5.pdf')
    fig.savefig(START+'/'+title+'-'+index+'_5.pgf')
    plt.close(fig)

In [None]:
#PLOTTING THE INTERACTION EXTENSION DISTRIBUTION FOR THE "continuous_interaction" (perfect helix interaction)

index = "continuous_interaction"
color = blue

# Axis labels and pivot tables for the perfect-helix table
labelx, labelx5, labely, xline5 = label(dfheatmap1, len(percent_list), index)

(heatmap_nr,
 h5_nr, h5_e, h5_p, h5_intraA, h5_intraB,
 mask5) = heatmap(dfheatmap1, index)

#Plotsize (single heatmap)
sizel = 15
sizeb = 15
ax1set = [0.05, 0.15, 0.95, 0.80]
ax2set = [0.095, 0.08, 0.72, 0.065]
scale = 0.5
text = [0.83, 0.108]

plot(nr=heatmap_nr, index=index,
     percent_list=percent_list,
     labelx=labelx, labely=labely,
     title=title,
     sizel=sizel, sizeb=sizeb,
     ax1set=ax1set, ax2set=ax2set,
     color=color, scale=scale,
     text=text)

#Plotsize (heatmap with 5 columns per constraint)
sizel = 11
sizeb = 37
ax1set = [0.05, 0.14, 0.95, 0.85]
ax2set = [0.055, 0.28, 0.545, 0.06]
scale = 0.4
text = [0.62, 0.3]

plot5(h5_nr=h5_nr, h5_e=h5_e, h5_p=h5_p,
      h5_intraA=h5_intraA, h5_intraB=h5_intraB,
      index=index,
      percent_list=percent_list,
      min_e=min_e, max_e=max_e,
      min_p=min_p, max_p=max_p,
      max_ia=max_ia,
      labelxc5=labelx5, labely=labely, xlinec5=xline5,
      title=title, mask=mask5,
      sizel=sizel, sizeb=sizeb,
      ax1set=ax1set, ax2set=ax2set,
      color=color,
      scale=scale,
      text=text)



In [None]:
#PLOTTING THE INTERACTION EXTENSION DISTRIBUTION FOR ALL interactions incl. e.g. bulges

index = "interaction"
color = green

# Use the same names as passed to plot5 below, so this cell plots the labels
# and separator positions of dfheatmap2 and not leftovers of the previous cell.
labelx, labelx5, labely, xline5 = label(dfheatmap2, len(percent_list),index)

heatmap_nr, h5_nr, h5_e, h5_p, h5_intraA, h5_intraB, mask5 = heatmap(dfheatmap2,index)

#Plotsize (single heatmap)
sizel, sizeb = 15,13
ax1set =[0.05, 0.15, 0.95, 0.80]
ax2set = [0.12, 0.08, 0.69, 0.065]
scale=0.5
text=[0.83,0.108]


plot(heatmap_nr, index,
     percent_list,
     labelx, labely,
     title,
     sizel, sizeb,
     ax1set, ax2set,
     color,scale,
     text
    )


#Plotsize (heatmap with 5 columns per constraint)
sizel, sizeb = 11,37
ax1set =[0.05, 0.14, 0.95, 0.85]
ax2set = [0.053, 0.25, 0.545, 0.06]
scale=0.44
text=[0.62,0.255]

plot5(h5_nr, h5_e, h5_p, h5_intraA, h5_intraB,
      index,
      percent_list,
      min_e, max_e,
      min_p, max_p,
      max_ia,
      labelx5, labely, xline5,
      title, mask5,
      sizel, sizeb,
      ax1set, ax2set,
      color,
      scale,
      text
      )