### CTS_goplot
- cts_goplot: Plot the enriched GO categories of the CTS genes

In [None]:
import sys
import os
import pandas as pd
import numpy as np

sys.path.append('../scripts')
from plot_helpers import *

%load_ext autoreload
%autoreload 2

In [None]:
# Directory the CTS figures are written to (created if it does not exist yet).
# The clusterProfiler GO tables plotted below are read from its 'genesets' subfolder.
outdir = '../Figures/CTS'
os.makedirs(outdir, exist_ok=True)

In [None]:
#abbreviate some words to make descriptions fit
def shorten_des(des):
    '''Abbreviate frequent words in a GO description so it fits on the y axis.

    The substitutions below are applied in order with plain substring
    replacement; afterwards every run of whitespace (including gaps left by
    dropping 'involved') is collapsed to one space and the ends are stripped.
    '''
    abbreviations = (
        ('plasma membrane', 'P.M.'),
        ('regulation', 'reg.'),
        ('positive', '+ve'),
        ('negative', '-ve'),
        ('differentiation', 'diff.'),
        ('morphogenesis', 'morph.'),
        ('response', 'resp.'),
        ('involved', ''),
    )
    shortened = des
    for long_form, short_form in abbreviations:
        shortened = shortened.replace(long_form, short_form)
    # split() with no argument drops leading/trailing whitespace as well
    return ' '.join(shortened.split())

def plot_GO(go_result, outname, title='', pval_cutoff=0.01, max_bg_size=200):
    '''Read in the output from ClusterProfiler and plot the GO categories.

    Draws a dot plot with one row per GO category. The x position is the
    fraction of the gene set annotated to the category. Marker area scales
    with the number of annotated genes. Colour encodes -log10(p.adjust).
    A colourbar and a marker-size legend sit in a narrow right-hand column.
    The figure is saved to '<outname>.<out_fmt>' at resolution ``out_dpi``.

    Parameters
    ----------
    go_result : str
        Path to a CSV containing the columns 'GeneRatio' and 'BgRatio'
        (both 'count/total' strings), 'p.adjust' and 'Description'.
    outname : str
        Output path without the file extension.
    title : str, optional
        Axis title, drawn in bold.
    pval_cutoff : float, optional
        Keep only categories with p.adjust strictly below this value.
    max_bg_size : int, optional
        Keep only categories annotated to fewer than this many background
        genes, so the plot favours the more specific terms.

    Returns
    -------
    None. If no category passes both filters, a warning is issued and no
    figure is drawn or saved.

    Notes
    -----
    Sets several global ``plt.rcParams`` font sizes to 6 as a side effect.
    ``plt``, ``plticker``, ``dfig``, ``out_fmt`` and ``out_dpi`` are not
    defined in this notebook. They are presumably provided by the star import
    from ``plot_helpers``; verify there.
    '''
    import warnings

    # --- read the clusterProfiler table and derive the plotted quantities ---
    go_df = pd.read_csv(go_result)
    # 'GeneRatio' is 'k/n': k genes of the n-gene subset fall in this category
    gene_ratio = go_df['GeneRatio'].str.split('/')
    go_df['go_n'] = gene_ratio.str[0].astype(int)
    go_df['subset_n'] = gene_ratio.str[1].astype(int)
    # first field of 'BgRatio' = number of background genes in the category
    go_df['num_in_bg'] = go_df['BgRatio'].str.split('/').str[0].astype(int)
    go_df['gene_ratio'] = go_df['go_n'] / go_df['subset_n']

    # limit plotting to most significant and more specific GO categories
    keep = (go_df['p.adjust'] < pval_cutoff) & (go_df['num_in_bg'] < max_bg_size)
    go_df = go_df[keep].copy()
    if go_df.empty:
        warnings.warn('%s: no GO categories with p.adjust < %s and fewer than %s '
                      'background genes; nothing plotted.'
                      % (go_result, pval_cutoff, max_bg_size))
        return
    go_df['-plog10'] = -np.log10(go_df['p.adjust'])
    go_df['edited_des'] = go_df['Description'].map(shorten_des)
    # ascending sort on a categorical y axis puts the most significant term on top
    go_df.sort_values(by='-plog10', ascending=True, inplace=True)

    # --- figure layout ---
    scale = 5  # scatter marker area (points^2) per annotated gene
    go_df['markersize'] = go_df['go_n'] * scale
    # Dynamically update rcparams (global side effect)
    plt.rcParams['font.size'] = 6
    plt.rcParams['xtick.labelsize'] = 6
    plt.rcParams['ytick.labelsize'] = 6
    plt.rcParams['axes.titlesize'] = 6
    plt.rcParams['axes.labelsize'] = 6

    fig = plt.figure(figsize=(dfig*2, dfig*2), constrained_layout=True)
    ncols = 15
    gs = fig.add_gridspec(ncols=ncols, nrows=2)
    gs.update(wspace=0.01)  # set the spacing between axes
    # main dot plot spans all but the last column; the last column holds the
    # colourbar (top row) and the marker-size legend (bottom row)
    ax = fig.add_subplot(gs[:, 0:ncols-1])
    cbar_ax = fig.add_subplot(gs[0, ncols-1])
    sizes_ax = fig.add_subplot(gs[1, ncols-1])

    # --- dot plot ---
    # https://matplotlib.org/stable/tutorials/colors/colormaps.html
    im = ax.scatter(x='gene_ratio', y='edited_des', s='markersize', c='-plog10',
                    cmap='viridis', data=go_df)
    ax.set_xlabel('fraction of gene set')
    ax.set_title(title, fontweight='bold')
    # pad the autoscaled limits so edge markers are not clipped
    axpad_y = 0.3
    axpad_x = 0.005
    y_lo, y_hi = ax.get_ylim()
    x_lo, x_hi = ax.get_xlim()
    ax.set_ylim(y_lo - axpad_y, y_hi + axpad_y)
    ax.set_xlim(x_lo - axpad_x, x_hi + axpad_x)

    # --- colourbar ---
    # https://stackoverflow.com/questions/6063876/matplotlib-colorbar-for-scatter
    cbar = fig.colorbar(im, cax=cbar_ax, orientation='vertical')
    cbar.set_label('-log'r'$_{10}$' ' pvalue')
    cbar.ax.yaxis.set_major_locator(plticker.MultipleLocator(base=0.5))

    # --- marker-size legend, labelled in number of genes ---
    # Legend ticks are chosen directly in gene-count space and mapped back to
    # marker areas via `func`. The old approach parsed digits out of the
    # formatted marker-area labels, which mislabels non-integer areas
    # (e.g. '7.5' -> '75' -> 15 genes).
    num_elements = 2
    to_gene_count = lambda area: area / scale
    count_locator = plticker.MaxNLocator(nbins=num_elements, min_n_ticks=1, integer=True)
    handles, labels = im.legend_elements('sizes', num=count_locator,
                                         func=to_gene_count, fmt='{x:g}')
    if not handles:
        # degenerate range (e.g. a single distinct size): one entry per size
        handles, labels = im.legend_elements('sizes', num=None,
                                             func=to_gene_count, fmt='{x:g}')
    sizes_ax.legend(handles, labels, loc='lower left', bbox_to_anchor=(0, 0),
                    borderaxespad=0, title='# genes')

    # the legend axes is only a container: hide its spines and axes
    sizes_ax.spines['bottom'].set_visible(False)
    sizes_ax.spines['left'].set_visible(False)
    sizes_ax.xaxis.set_visible(False)
    sizes_ax.yaxis.set_visible(False)

    fig.savefig('%s.%s' % (outname, out_fmt), dpi=out_dpi)

In [None]:
# GO (BP) enrichment of the stable CTS gene set -> <outdir>/CTS_stable.<fmt>
plot_GO(
    go_result=os.path.join(outdir, 'genesets', 'CTS_10stable_genes_BP_0.5.csv'),
    outname=os.path.join(outdir, 'CTS_stable'),
    title='stable CTS genes',
)

In [None]:
# GO (BP) enrichment of the unstable CTS gene set -> <outdir>/CTS_unstable.<fmt>
plot_GO(
    go_result=os.path.join(outdir, 'genesets', 'CTS_10unstable_genes_BP_0.5.csv'),
    outname=os.path.join(outdir, 'CTS_unstable'),
    title='unstable CTS genes',
)