#### stability and function
- Plot the stability percentile by GO slim categories from function
- Plot the relationship between stability percentile and average TPM for these various categories
##### Comparing the GO Summary Ribbon mapping table (Flybase) to the GO Slim categories from AGR reveals that they are very similar.
The GO Slim categories for MF which are not represented in the Flybase table are:
- enzyme regulator activity
- DNA-binding TF activity
- TF binding
- carbohydrate derivative binding

In contrast, they include:
- small molecule binding
- molecular function regulator
- transcription regulator activity

It seems less confusing that they have only chosen one TF-related category, so I'm going to go with the Flybase GO summary ribbons for this figure

In [None]:
#Imports
import sys
import os
import pandas as pd
import seaborn as sns
import numpy as np
import math
import scipy.stats as stats
from scipy.stats import percentileofscore
import goatools
import pybiomart as pb
from goatools.base import get_godag
from goatools.gosubdag.gosubdag import GoSubDag
from goatools.anno.genetogo_reader import Gene2GoReader
from collections import defaultdict
import scipy.stats
import scipy.signal
import gffutils
from gffutils import FeatureNotFoundError
import matplotlib as mpl

sys.path.append('../scripts')
from plot_helpers import *
from plotting_fxns import extract_gene_vals
from utilities import load_dataset
# Open the gffutils annotation database; remove_noncoding_genes() below uses it
# to look up each gene's gene_biotype attribute.
# NOTE(review): gffutils_db is not defined in this cell -- presumably it is
# exported by the star import from plot_helpers; confirm.
db = gffutils.FeatureDB(gffutils_db)

%load_ext autoreload
%autoreload 2

In [None]:
# Output directory for the figures and gene lists written by the cells below.
# Created up front (no error if it already exists) so later writes find it.
outdir = '../Figures/GO'
os.makedirs(outdir, exist_ok=True)

In [None]:
# Flybase GO Summary Ribbons (tile name -> list of GO IDs), as described at
# https://wiki.flybase.org/wiki/FlyBase:Gene_Report
# Each table stores the GO IDs of a tile as one comma-separated string, which
# is split into a list before the lookup dictionary is built.

# Molecular function ribbon
mf_ribbon_file = '../../resources/flybase_files/flybase_MF_summaryribbons.csv'
mf_ribbon_df = pd.read_csv(mf_ribbon_file)
mf_ribbon_df['GO ID(s)'] = mf_ribbon_df['GO ID(s)'].str.split(', ')
mf_slim_dict = dict(zip(mf_ribbon_df['Tile Name'], mf_ribbon_df['GO ID(s)']))

# Cellular component ribbon
cc_ribbon_file = '../../resources/flybase_files/flybase_CC_summaryribbons.csv'
cc_ribbon_df = pd.read_csv(cc_ribbon_file)
cc_ribbon_df['GO ID(s)'] = cc_ribbon_df['GO ID(s)'].str.split(', ')
cc_slim_dict = dict(zip(cc_ribbon_df['Tile Name'], cc_ribbon_df['GO ID(s)']))

# The catch-all "other" tiles are excluded from the analysis
del mf_slim_dict['other molecular_function']
del cc_slim_dict['other cellular_component']

In [None]:
# STEP 1: dictionary of GO terms as key and the genes related as values
# Query the Jan 2020 Ensembl archive through pybiomart for (GO term, gene) pairs
dataset = pb.Dataset(name='dmelanogaster_gene_ensembl', host='jan2020.archive.ensembl.org/')
df_11 = dataset.query(attributes=['go_id', 'ensembl_gene_id'])
# Drop rows with a missing value in either column
df_12 = df_11.dropna()
# The returned columns are addressed by display name, not by the attribute
# names used in the query above
list_11 = df_12.groupby('GO term accession')['Gene stable ID'].apply(set)
# ens_dic: GO ID -> set of gene IDs
ens_dic = list_11.to_dict()
#There are some differences between the Flybase and Ensembl annotations
#Ensembl annotations include slightly more genes, so those will be used

In [None]:
# STEP 1b: Use the annotation file from Flybase to get the genes associated to each GOterm
# allow all evidence codes, since I think that is what Flybase does for the annotation ribbons
#http://ftp.flybase.net/releases/FB2019_03/precomputed_files/go/gene_association.fb.gz
assoc_file = '../../resources/flybase_files/gene_association.fb'
# The file is read without a header, so columns are addressed by position.
# NOTE(review): skiprows=5 assumes exactly five leading non-data lines in this
# release of the file -- re-check if the file is updated.
assoc_df = pd.read_csv(assoc_file, header=None, skiprows=5, sep='\t')
#get ones without the NOT, colocalizes with, relationships, etc.
# i.e. keep only the rows whose column 3 is empty
assoc_df = assoc_df.loc[pd.isnull(assoc_df[3])].copy()
# fb_dic: value of column 4 -> set of values of column 1
# (presumably GO ID -> FlyBase gene ID, going by how fb_dic is used as go2genes
# below -- TODO confirm against the file's column layout)
fb_dic = assoc_df.groupby(4)[1].apply(set).to_dict()
# Seems like Ensembl and Flybase are returning different genesets, so let's collect both to compare

In [None]:
# STEP 2: Get the GO graph
# The optional 'relationship' attribute is requested because get_subterms()
# below asks GoSubDag to follow the part_of relationship.
godag = get_godag('go-basic.obo', optional_attrs={'relationship'})

In [None]:
# STEP 3: Find all descendants of specified GO terms
#https://github.com/tanghaibao/goatools/blob/main/notebooks/parents_and_ancestors.ipynb
def get_subterms(slim_dict, go2genes, remove_noncoding=True):
    '''
    Collect the genes annotated to each GO category or to any of its descendant terms.

    slim_dict is a dictionary mapping category name -> list of GO IDs defining the category.
    go2genes is a dictionary mapping GO_ID -> set of genes.
    If remove_noncoding is True, the result is passed through remove_noncoding_genes().

    Descendants are looked up in the module-level godag with the optional
    part_of relationship enabled.

    Returns a dictionary mapping category name -> set of genes. A category with
    several GO IDs gets the union of the genes of all of them. A category with
    an empty GO ID list gets no entry.
    '''
    GO_genes = {}
    optional_relationships = {'part_of',}

    for goterm, go_ids in slim_dict.items():
        gosubdag_r0 = GoSubDag(go_ids, godag, relationships=optional_relationships, prt=None)
        go2descendants = gosubdag_r0.rcntobj.go2descendants
        for go_id in go_ids:
            # Accumulate into one set per category: assigning a fresh set per
            # GO ID would keep only the genes of the last ID in the list.
            geneset = GO_genes.setdefault(goterm, set())
            try:
                descendants = go2descendants[go_id]
            except KeyError:
                # no descendant entry for this term
                descendants = set()
            # the term itself counts as well as its descendants
            for term in (go_id, *descendants):
                if term in go2genes:
                    geneset.update(go2genes[term])

    if remove_noncoding:
        GO_genes = remove_noncoding_genes(GO_genes)

    return GO_genes

def remove_noncoding_genes(go_genes):
    '''
    Keep only the protein-coding genes of every GO category.

    go_genes maps category -> collection of gene IDs. Each ID is looked up in db
    (a gffutils db); a gene is kept when the first value of its gene_biotype
    attribute is 'protein_coding'. Genes that are not found in db are dropped.

    Without this filter, non-coding genes like tRNAs make up a large portion of
    some categories, like RNA-binding proteins.

    Returns a new dictionary mapping category -> set of gene IDs.
    '''
    def _is_protein_coding(gene_id):
        try:
            biotype = db[gene_id].attributes['gene_biotype'][0]
        except FeatureNotFoundError:
            return False
        return biotype == 'protein_coding'

    return {
        category: {gene_id for gene_id in members if _is_protein_coding(gene_id)}
        for category, members in go_genes.items()
    }
    
def calc_go_stab(GO_genes, stab_df):
    '''
    Summarise stability and expression level for each GO Slim category.

    GO_genes maps category -> collection of gene IDs. stab_df is indexed by gene
    ID and provides the stab_percentile and tot_level columns.

    Returns a DataFrame indexed by category with the columns
    stability (median stab_percentile), tpm_total (median tot_level) and
    num_genes (number of the category's genes that are present in stab_df).
    '''
    summary = {}
    for category, members in GO_genes.items():
        # rows of stab_df that belong to this category
        in_category = stab_df[stab_df.index.isin(members)]
        summary[category] = {
            'stability': in_category['stab_percentile'].median(),
            'tpm_total': in_category['tot_level'].median(),
            'num_genes': len(in_category),
        }
    return pd.DataFrame.from_dict(summary, orient='index')

In [None]:
# Load stability data
# NOTE(review): the meaning of the second file (brain4sU_passed.csv) is decided
# inside utilities.load_dataset -- presumably a list of genes passing filters; confirm.
rate_df = load_dataset('../Figures/summary_files/INSPEcT_rates.csv', '../Figures/summary_files/brain4sU_passed.csv')

# Get coding subset of the stability data to use for making the GO plots so that I can label them 'functional classes of mRNA'
coding_rate_df = rate_df.query('biotype=="protein_coding"')

# Report how many rows remain after restricting to protein-coding genes
print('len rate df', len(rate_df))
print('len coding rate df', len(coding_rate_df))

In [None]:
#Find the median stability for different GO Slim categories
# Gene sets are built for both ribbons (mf, cc) from both annotation sources
# (ens, fb); get_subterms removes non-coding genes by default.

#It does find all of the genes from Ensembl
mf_subterms_ens = get_subterms(mf_slim_dict, ens_dic)
cc_subterms_ens = get_subterms(cc_slim_dict, ens_dic)
#It does not find all the genes from Flybase because some of them are unannotated
mf_subterms_fb = get_subterms(mf_slim_dict, fb_dic)
cc_subterms_fb = get_subterms(cc_slim_dict, fb_dic)
#Use ensembl terms to write the output
# Per-category medians are computed on the protein-coding subset only
mf_stab = calc_go_stab(mf_subterms_ens, coding_rate_df)
cc_stab = calc_go_stab(cc_subterms_ens, coding_rate_df)

In [None]:
#Write the output files using either ensembl (ens) or flybase (fb) versions
def write_golists(term_dic, sourcename, other_groups=None, go2genes=None):
    '''
    Write one text file per GO category listing its genes, one gene ID per line.

    term_dic maps category name -> iterable of gene IDs.
    sourcename names the subdirectory of outdir that receives the files.
    other_groups optionally maps extra category names -> lists of GO IDs. Their
    genes are resolved with get_subterms(other_groups, go2genes) and written
    alongside the categories of term_dic.
    go2genes maps GO_ID -> set of genes and is only used when other_groups is given.

    The files are written to <outdir>/<sourcename>/<category>.txt and the
    directory is created if needed. term_dic itself is not modified.
    '''
    # Copy unconditionally so that the function also works when other_groups
    # is None (the default).
    all_terms = dict(term_dic)
    if other_groups is not None:
        all_terms.update(get_subterms(other_groups, go2genes))

    target_dir = os.path.join(outdir, sourcename)
    os.makedirs(target_dir, exist_ok=True)
    for category, genes in all_terms.items():
        with open(os.path.join(target_dir, '%s.txt' % category), 'w') as g:
            g.writelines('%s\n' % gene for gene in genes)

#add any additional desired terms to report genes for here:
# (category name -> list of GO IDs, the same shape as the slim dictionaries)
other_groups = {'mRNA binding': ['GO:0003729']}

# Write the molecular function gene lists once per annotation source, into the
# 'ens' and 'fb' subdirectories of outdir
write_golists(mf_subterms_ens, 'ens', other_groups=other_groups, go2genes=ens_dic)
write_golists(mf_subterms_fb, 'fb', other_groups=other_groups, go2genes=fb_dic)

In [None]:
# Build a long-form table with one row per (category, gene): the stability
# percentile, the category name and the total TPM.
# Only the coding genes are used.
# Categories are ordered by increasing median stability.
cats = mf_stab.sort_values(by='stability').index
data_stab = []
data_tpm = []
frames = []
for cat in cats:
    in_cat = coding_rate_df.index.isin(mf_subterms_ens[cat])
    stab_vals = coding_rate_df.loc[in_cat, 'stab_percentile'].values
    tpm_vals = coding_rate_df.loc[in_cat, 'tot_level'].values
    data_stab.append(stab_vals)
    data_tpm.append(tpm_vals)
    frames.append(pd.DataFrame({
        'stab_percentile': stab_vals,
        'cat': [cat]*len(stab_vals),
        'total_tpm': tpm_vals,
    }))
big_df = pd.concat(frames)

In [None]:
# Correlation of total RNA levels with RNA stability
# This plot includes the non-coding RNAs
fig = plt.figure(figsize=(dfig, dfig), constrained_layout=True)
ax = fig.add_subplot(111)
# Both variables are log10-transformed before correlating and plotting
x = rate_df['deg_rate'].apply(np.log10)
y = rate_df['tot_level'].apply(np.log10)
#rval, pval = stats.pearsonr(x, y)
rval, pval = stats.spearmanr(x, y)

# NOTE(review): this is the square of the Spearman coefficient, yet the plot
# label below reads r^2 -- confirm that this is the intended statistic.
r2_val_av = rval**2
#ax.scatter(x, y)
# 2D histogram of decay rate against total level
ax = sns.histplot(x=x, y=y, cmap='rocket', ax=ax, zorder=2)
# ax.text(0.05, 0.9, 'r'r'$^2$'' = %1.2f' % r2_val_av, fontsize = 8, transform=ax.transAxes)
ax.text(0.6, 0.9, 'r'r'$^2$'' = %1.2f' % r2_val_av, fontsize = 8, transform=ax.transAxes)
ax.set_ylabel('total RNA level\n(log'r'$_{10}$'' TPM)')
ax.set_xlabel('decay rate (log'r'$_{10}$'' 1 / min)')
ax.set_aspect('equal')
# yloc = plticker.MultipleLocator(base=2.0)
# xloc = plticker.MultipleLocator(base=5.0)
# ax.xaxis.set_major_locator(xloc)
# ax.yaxis.set_major_locator(xloc)
# A few genes with very low degradation rates make the x-axis extend very far.
# The x-range is therefore fixed by hand.
ax.set_xlim(-5, 1)
plt.savefig('%s.%s' % (os.path.join(outdir, 'tot_vs_deg_scat'), out_fmt), dpi = out_dpi)

In [None]:
# #Examine distribution of TPM for all expressed genes
# ax=sns.histplot(stab_df['total_tpm'].apply(np.log10))

In [None]:
#Look at median stability and TPM of these categories
# Reference values: median total level over all genes and over coding genes only
print('median TPM of expressed genes', rate_df['tot_level'].median())
print('median TPM of expressed genes, coding', coding_rate_df['tot_level'].median())

# Per-category summary table, least stable category first
mf_stab.sort_values(by='stability')

In [None]:
#Plot ridgeline plot
# One row per category in cats, each showing the density of stability
# percentiles for the genes of that category.
# num_cats is kept because the later ridge-plot cells read it.
num_cats = len(big_df['cat'].unique())
# Line-wrapped versions of long category names for the row labels
cat_pretty_names = {'transcription factor':'transcription\nfactor', 'carbohydrate binding':'carbohydrate\nbinding', 
                   'lipid binding':'lipid\nbinding', 'receptor binding':'receptor\nbinding', 'cytoskeleton binding':'cytoskeleton\nbinding',
                   'metal ion binding':'metal ion\nbinding', 'small molecule binding':'small molecule\nbinding',
                    'structural molecule':'structural\nmolecule'}
#Set the background to transparent to allow overlap:
#https://stackoverflow.com/questions/4581504/how-to-set-opacity-of-background-colour-of-graph-with-matplotlib
# All three colours are RGBA tuples with alpha = 0, i.e. fully transparent
plt.rcParams.update({
    "figure.facecolor":  (0.0, 0.0, 0.0, 0.0),
    "axes.facecolor":    (0.0, 0.0, 0.0, 0.0),
    "savefig.facecolor": (0.0, 0.0, 0.0, 0.0),
})
# One colour per row. The palette is sized by cats (the rows that are drawn)
# rather than by num_cats, so pal[i] exists even if a category has no rows in big_df.
pal = sns.cubehelix_palette(len(cats), rot=-.25, light=.7)
fig = plt.figure(figsize=(dfig, dfig*2))
gs = fig.add_gridspec(ncols=1, nrows = len(cats))
for i, cat in enumerate(cats):
    cat_df = big_df[big_df['cat']==cat]
    ax = fig.add_subplot(gs[i])
    if i==0:
        ax.set_title('functional classes of mRNAs')
    # Filled density first, then a white outline of the same curve on top.
    # The target axes is passed explicitly instead of relying on the current axes.
    sns.kdeplot(x='stab_percentile', data=cat_df, bw_adjust=.5, clip_on=False,
                fill=True, alpha=1, linewidth=1.5, color=pal[i], ax=ax)
    sns.kdeplot(x='stab_percentile', data=cat_df, clip_on=False, color="w", lw=2, bw_adjust=.5, ax=ax)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.set_yticks([])
    ax.set_ylabel('')
    # Only the bottom row keeps its x ticks and axis label
    if i != len(cats)-1:
        ax.set_xticks([])
        ax.set_xlabel('')
    else:
        ax.set_xlabel('stability percentile')
    label = cat_pretty_names.get(cat, cat)
    # Row label placed to the left of the axes, in axes coordinates
    ax.text(-0.28, .3, label, fontsize=5, fontweight="bold",
            ha="left", va="center", transform=ax.transAxes, color=pal[i])

# Negative hspace makes the rows overlap vertically
fig.subplots_adjust(hspace=-.3)
#this adds space on the left side to make room for the labels
fig.subplots_adjust(left=0.2)
plt.savefig('%s.%s' % (os.path.join(outdir, 'GOslim_ridge'), out_fmt), dpi = out_dpi)

In [None]:
#How to get colors from color map
#https://stackoverflow.com/questions/28144142/how-can-i-generate-a-colormap-array-from-a-simple-array-in-matplotlib
def med_tpm(x, y, bins='', vmin=0, vmax=100, filter_size=5, polyorder=3, dim_x=200):
    '''
    Get the median TPM from the binned stability percentiles and convert it to colors.

    x, y: values to bin (by x) and to take the per-bin median of (y).
    bins: bin specification passed to scipy.stats.binned_statistic. The empty-string
        default is not a valid specification there, so it is replaced by one bin
        per percentile, np.arange(0, 101, 1).
    vmin, vmax: limits used to normalize the smoothed medians before the magma
        colormap lookup.
    filter_size, polyorder: window length and polynomial order passed to
        scipy.signal.savgol_filter.
    Dim_x is the length that the returned array needs to be in the end. Each bin
        color is repeated int(dim_x / number of bins) times, so the length is
        exactly dim_x only when dim_x is a multiple of the number of bins.

    Returns the array of colors produced by the colormap, one row per repeated bin.
    '''
    if isinstance(bins, str) and not bins:
        bins = np.arange(0, 101, 1)
    stat, _edges, _binnumber = scipy.stats.binned_statistic(x, y, statistic='median', bins=bins, range=None)
    # np.repeat expects an integer count. The number of bins is taken from the
    # returned statistic so that an integer bins argument works as well.
    repeats = int(dim_x / len(stat))
    # interpolate missing data from bins that have no genes
    # If you don't interpolate, then the smoothing doesn't work on any windows which contain NaNs
    # interpolate doesn't fill in the values at the edges, which is why some categories start/end earlier
    stat = pd.Series(stat).interpolate().values
    stat2 = scipy.signal.savgol_filter(stat, filter_size, polyorder)
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    colors_stab = plt.cm.magma(norm(stat2))
    return colors_stab.repeat(repeats, axis=0)

# colors_stab1 = med_tpm(big_df.query('cat == "transcription factor"')['stab_percentile'], big_df.query('cat == "transcription factor"')['total_tpm'], vmin=0, vmax=100, bins=bins, filter_size=21, dim_x=200)
# colors_stab2  = med_tpm(big_df.query('cat == "receptor"')['stab_percentile'], big_df.query('cat == "receptor"')['total_tpm'], vmin=0, vmax=100, bins=bins, filter_size=21, dim_x=200)

In [None]:
# Plot GO categories with expression overlaid, for this one put the color bar on the right
# Each row is the stability-percentile density of one category; the area under
# the curve is colored by the smoothed median TPM returned by med_tpm.
# big_df['log_tpm'] = big_df['total_tpm'].apply(np.log10)
# plt.cm.magma is the same colormap object that med_tpm uses; its first color
# is used for the density outline.
cmap = plt.cm.magma
dark = cmap(0)
# Only one figure is created, so no empty extra figure is left open
fig = plt.figure(figsize=(dfig*1.5, dfig*2))
gs = fig.add_gridspec(ncols=1, nrows = len(cats))
# Leave the right part of the figure free for the color bar
gs.update(left=0.15, right=0.645)
# vmin = big_df['log_tpm'].min()
# vmax = big_df['log_tpm'].max()
vmin = 0
vmax = 100
#len of bins is n+1 relative to the number of bins
bins = np.arange(0, 101, 1)
colors_dict = {}
for i, cat in enumerate(cats):
    sdf = big_df[big_df['cat']==cat].copy()
    ax = fig.add_subplot(gs[i])
    if i==0:
        ax.set_title('functional classes of mRNAs')
    ax = sns.kdeplot(x='stab_percentile', data=sdf, clip_on=False, color=dark, lw=0.75, bw_adjust=.5, ax=ax)
    # Coordinates of the density line that was just drawn
    ydata = ax.lines[0].get_ydata()
    xdata = ax.lines[0].get_xdata()
    #Now get the smoothed median TPM for the genes
    # NOTE(review): dim_x=200 presumably matches the number of points on the
    # density line, since colors_stab is indexed by line point below -- confirm.
    colors_stab = med_tpm(sdf['stab_percentile'], sdf['total_tpm'], vmin=vmin, vmax=vmax, bins=bins, filter_size=21, dim_x=200)
    # colors_stab = med_tpm(sdf['stab_percentile'], sdf['log_tpm'], bins=bins, filter_size=21, dim_x=200, vmin=vmin, vmax=vmax)
    colors_dict[i] = colors_stab

    # Fill under the curve one line segment at a time, each with its own color
    for j in range(len(xdata)):
        ax.fill_between(xdata[j:j+2], ydata[j:j+2], color=colors_stab[j])

    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.set_yticks([])
    ax.set_ylabel('')
    # Only the bottom row keeps its x ticks and axis label
    if i != len(cats)-1:
        ax.set_xticks([])
        ax.set_xlabel('')
    else:
        ax.set_xlabel('stability percentile')
    label = cat_pretty_names.get(cat, cat)
    # Convert the axes-relative height 0.3 into figure coordinates so the label
    # can be anchored at the left edge of the figure.
    # NOTE(review): this is evaluated before subplots_adjust(hspace=...) below
    # changes the layout -- confirm the label heights are as intended.
    axis_to_fig = ax.transAxes + fig.transFigure.inverted()
    points_axis = axis_to_fig.transform((0, 0.3))
    ax.text(0, points_axis[1], label, fontsize=5, fontweight="bold",
            ha="left", va="center", transform=fig.transFigure, color='k')
# Overlap the plots vertically
fig.subplots_adjust(hspace=-.2)
# This adds space on the left side to make room for the labels
# fig.subplots_adjust(left=0.2)
# Separate one-cell grid on the right for the color bar
gs2 = fig.add_gridspec(ncols=1, nrows=1)
gs2.update(left=0.7, right=0.75)
cbar_ax = fig.add_subplot(gs2[0])
cbar = plt.colorbar(mpl.cm.ScalarMappable(cmap='magma', norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax)), 
                    cax=cbar_ax, orientation='vertical', label='expression level (total RNA TPM)')
ticks = cbar_ax.get_yticks()
# The top of the scale stands for everything at or above 100 TPM
dic = {100: r'$\geq$''100'}
labels = [dic.get(t, str(int(t))) for t in ticks]
# Fix the tick positions before assigning labels so that each label stays
# attached to the tick it was computed for
cbar.set_ticks(ticks)
cbar_ax.set_yticklabels(labels)
plt.savefig('%s.%s' % (os.path.join(outdir, 'GOslim_sunsetridge2'), out_fmt), dpi = out_dpi)

In [None]:
# Narrower figure version with colorbar on the top
# Each row is the stability-percentile density of one category; the area under
# the curve is colored by the smoothed median TPM returned by med_tpm.
# plt.cm.magma is the same colormap object that med_tpm uses; its first color
# is used for the density outline.
cmap = plt.cm.magma
dark = cmap(0)
# Only one figure is created, so no empty extra figure is left open
fig = plt.figure(figsize=(dfig, dfig*2))
gs = fig.add_gridspec(ncols=1, nrows = len(cats))
#len of bins is n+1 relative to the number of bins
bins = np.arange(0, 101, 1)
for i, cat in enumerate(cats):
    sdf = big_df[big_df['cat']==cat].copy()
    ax = fig.add_subplot(gs[i])
    # if i==0:
    #     ax.set_title('functional classes of mRNAs')
    ax = sns.kdeplot(x='stab_percentile', data=sdf, clip_on=False, color=dark, lw=0.75, bw_adjust=.5, ax=ax)
    # Coordinates of the density line that was just drawn
    ydata = ax.lines[0].get_ydata()
    xdata = ax.lines[0].get_xdata()
    #Now get the smoothed median TPM for the genes
    # NOTE(review): dim_x=200 presumably matches the number of points on the
    # density line, since colors_stab is indexed by line point below -- confirm.
    colors_stab = med_tpm(sdf['stab_percentile'], sdf['total_tpm'], vmin=0, vmax=100, bins=bins, filter_size=21, dim_x=200)
    # Fill under the curve one line segment at a time, each with its own color
    for j in range(len(xdata)):
        ax.fill_between(xdata[j:j+2], ydata[j:j+2], color=colors_stab[j])
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.set_yticks([])
    ax.set_ylabel('')
    # Only the bottom row keeps its x ticks and axis label
    if i != len(cats)-1:
        ax.set_xticks([])
        ax.set_xlabel('')
    else:
        ax.set_xlabel('stability percentile')
    label = cat_pretty_names.get(cat, cat)
    # Row label placed to the left of the axes, in axes coordinates
    ax.text(-0.28, .3, label, fontsize=5, fontweight="bold",
            ha="left", va="center", transform=ax.transAxes, color='k')

# Negative hspace makes the rows overlap vertically
fig.subplots_adjust(hspace=-.2)
#this adds space on the left side to make room for the labels
fig.subplots_adjust(left=0.2)
# Horizontal color bar in its own axes near the top of the figure
cbar_ax = fig.add_axes([0.22, 0.9, 0.63, 0.02])
cbar = plt.colorbar(mpl.cm.ScalarMappable(cmap='magma', norm=mpl.colors.Normalize(vmin=0, vmax=100)), cax=cbar_ax, orientation='horizontal', label='expression level (total RNA TPM)')
cbar_ax.xaxis.set_ticks_position('top')
cbar_ax.xaxis.set_label_position('top')
cbar_ax.tick_params(axis='x',length=0)
cbar_ax.tick_params(axis='x', labelsize=5)
cbar_ax.set_xlabel(cbar_ax.get_xlabel(), fontsize=5)
ticks = cbar_ax.get_xticks()
# The top of the scale stands for everything at or above 100 TPM
dic = {100: r'$\geq$''100'}
labels = [dic.get(t, str(int(t))) for t in ticks]
# Fix the tick positions before assigning labels so that each label stays
# attached to the tick it was computed for
cbar.set_ticks(ticks)
cbar_ax.set_xticklabels(labels)
plt.savefig('%s.%s' % (os.path.join(outdir, 'GOslim_sunsetridge'), out_fmt), dpi = out_dpi)

There is white under some regions of some lines because there are no TPM values in that range
for which to calculate the corresponding bin TPM.

In [None]:
#Pearson correlation between synthesis rate and total level, computed over all
#coding genes ('all') and separately for each molecular function category
#'Figures/Overview'
# Only genes with both values present are used
this_df = coding_rate_df[['syn_rate', 'tot_level']].dropna(how='any')

def _syn_vs_total(frame):
    # correlation coefficient and p-value for one group of genes
    r_val, p_val = stats.pearsonr(frame['syn_rate'], frame['tot_level'])
    return {'r_val': r_val, 'p_val': p_val}

d = {'all': _syn_vs_total(this_df)}
for group, members in mf_subterms_ens.items():
    d[group] = _syn_vs_total(this_df.loc[this_df.index.isin(members)])

corr_df = pd.DataFrame.from_dict(d, orient='index')
corr_df

This shows that the correlation between synthesis rate and total RNA is actually higher for TF RNAs than for other RNAs. This is a little counterintuitive, but if a gene has a shorter half-life, then its synthesis rate will be more highly correlated with its total level than for a gene with a longer half-life.
So it doesn't really tell you how much of the total RNA level can be explained by RNA decay.

In [None]:
#Colorbar formatting tests
# Draws only the horizontal color bar of the ridge figure, on an otherwise
# empty figure of the same size, and saves it as cbar_test.
fig = plt.figure(figsize=(dfig, dfig*2))
# NOTE(review): gs is not used again in this cell
gs = fig.add_gridspec(ncols=1, nrows = len(cats))
cbar_ax = fig.add_axes([0.22, 0.9, 0.63, 0.02])
cbar = plt.colorbar(mpl.cm.ScalarMappable(cmap='magma', norm=mpl.colors.Normalize(vmin=0, vmax=100)), cax=cbar_ax, orientation='horizontal', label='expression level (total RNA TPM)')
# Move ticks and label above the bar, hide the tick marks and shrink the text
cbar_ax.xaxis.set_ticks_position('top')
cbar_ax.xaxis.set_label_position('top')
cbar_ax.tick_params(axis='x',length=0, labelsize=5)
cbar_ax.set_xlabel(cbar_ax.get_xlabel(), fontsize=5)
plt.savefig('%s.%s' % (os.path.join(outdir, 'cbar_test'), out_fmt), dpi = out_dpi)