## MitoRibo Bayesian inference to obtain timescale posteriors
Author: Robert Ietswaart  
Date: 20220205  
License: BSD2.  
Load modules j3dl and activate virtual environment using j4RNAdecay on O2.  
Python v3.7.4

Source: `Timescale_Bayes_20210615.ipynb`  
For Erik's mitochondrial RNA kinetics project using mitoribo subcellular TimeLapse-seq: Bayesian inference to obtain timescale posteriors from the GRAND-SLAM new-to-total RNA ratio (NTR) posteriors of individual genes in the mito-total and mitoribosome fractions.

In [1]:
import os
import numpy as np
import pandas as pd
import logging
import argparse
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
import seaborn as sns
from scipy.integrate import quad, cumulative_trapezoid 

# from __init__ import __version__
from __init__ import default_logger_format, default_date_format

import new_total_ratio_jit as ntr
import posteriors_jit as p

import numba as nb
from numba import jit
from numba.core import types
from numba.typed import Dict

In [2]:
# Originally the body of a main() function; it now runs at notebook top level.
# Seed NumPy's global RNG so that any random draws are reproducible.
np.random.seed(12345)

parser = argparse.ArgumentParser(
    description='Run Bayesian timescale fitting on mitoribo subcellular timelapse seq',
)

# Parse an explicitly empty argument list: inside a notebook kernel sys.argv
# holds the kernel's own flags, which this argument-less parser would reject.
args = parser.parse_args([])

### Load input files

In [3]:
path = os.path.join('/n','groups','churchman','ri23','RNAdecay','GS20220204')
outpath = os.path.join('/n','groups','churchman','ri23','RNAdecay','Bayes20220204_mitoribo')

# Project-specific logger for this processing stage; records go to a file in
# <outpath>/LogErr.
logger = logging.getLogger('scTLseq_mitoribo')
# The rest of the notebook reports progress through logger.info(). Without an
# explicit level the logger inherits the root level (WARNING unless configured
# elsewhere), which would silently drop every info record.
logger.setLevel(logging.INFO)
log_file = os.path.join(outpath, 'LogErr', 'scTLseq_mitoribo_20220204.log')
formatter = logging.Formatter(default_logger_format,
                              datefmt=default_date_format)
# Re-running this cell in the same kernel would otherwise attach a second
# handler for the same file and duplicate every log line.
already_attached = any(
    isinstance(h, logging.FileHandler)
    and h.baseFilename == os.path.abspath(log_file)
    for h in logger.handlers
)
if not already_attached:
    log_handler = logging.FileHandler(log_file)
    log_handler.setFormatter(formatter)
    logger.addHandler(log_handler)

# Fractions sequenced, and the kinetic model(s) fitted to each fraction.
fracs = ['tot', 'mitoribo']
fracs_model = {
    'tot': ['tot_fit'],
    'mitoribo': ['mitoribo_fit'],
}
# Input file suffix per fraction; the replicate id is prepended when loading.
frac_filename = {
    'tot': '_tot_MTdsrRNA_t5MTMMinformed6_modeAll_PcMod3tot.tsv',
    'mitoribo': '_IP_MTdsrRNA_t5MTMMinformed6_modeAll_PcMod4IP.tsv',
}
reps = ['TL8']
background_id = dict.fromkeys(reps, '1')

# Time points in minutes; the first (0 min) is excluded from the measured time
# points used by the posteriors.
time_mins = [0, 15, 30, 60]
time_measured = np.asarray(time_mins[1:])
# The two values below only serve plotting of continuous model curves.
T_max = time_mins[-1] + 30
time_cont = pd.Series(range(T_max))

# Beta-distribution parameter names read per time point from the input tables.
CI_para = ['alpha', 'beta']
# Summary statistics reported per posterior. The quantile labels must agree
# with alpha/2 and 1 - alpha/2 used in get_post_ci().
OUT_TYPES = [' Mean', ' MAP', ' 0.975 quantile', ' 0.025 quantile']

Timescales = ['T_mito_entry', 'T_deg']

# Load the table of every replicate/fraction pair whose file exists; missing
# files are skipped silently, so GS may lack some '<rep><fraction>' keys.
GS = {}
for r in reps:
    for fr in fracs:
        filename = r + frac_filename[fr]
        gs_path = os.path.join(path, filename)
        if not os.path.exists(gs_path):
            continue
        GS[r + fr] = pd.read_csv(gs_path, sep='\t')

### Posterior calculations

In [5]:
# Rate grid and integration domain for the posteriors (rates in min^-1).
N_grid = 1000      # number of log-spaced grid points on [k_bound_lo, k_bound_hi]
k_bound_lo = 1e-4  # ~ once per 7 days
k_bound_hi = 1e4   # ~ once per 6 ms; raise (e.g. to 1e6) if too restrictive

@jit(nb.float64(nb.float64, types.DictType(types.unicode_type, nb.float64[:])), 
     nopython=True)
def post_mitototal(kD, NTR):
    """Unnormalised posterior density of the rate kD given total-fraction NTR data.

    For every measured time point t, the model NTR ntr.lam_total_one_step(kD, t)
    is scored with p.beta_pdf under the Beta(NTR['alpha'][j], NTR['beta'][j])
    parameters and multiplied by p.det_jac_mitototal(kD, t) (presumably a
    change-of-variables Jacobian -- confirm in posteriors_jit). The product
    over all time points is returned.

    kD : rate (min^-1); get_model_pred uses this model with the T_deg rate.
    NTR : typed dict with 'alpha' and 'beta' arrays, one entry per element of
          time_measured.
    """
    density = 1.0
    n_times = time_measured.shape[0]
    for j in range(n_times):
        t_j = time_measured[j]
        ntr_pred = ntr.lam_total_one_step(kD, t_j)
        a_j = NTR['alpha'][j]
        b_j = NTR['beta'][j]
        like_j = p.beta_pdf(ntr_pred, a_j, b_j)
        # Keep the left-to-right multiplication order of the product.
        density = density * like_j * p.det_jac_mitototal(kD, t_j)
    return density

@jit(nb.float64(nb.float64, nb.float64, types.DictType(types.unicode_type, nb.float64[:])), 
     nopython=True)
def post_mitoribo(kL, kD, NTR):
    """Unnormalised joint posterior density of (kL, kD) given mitoribo-fraction NTR data.

    For every measured time point t, the model NTR ntr.lam_ribo(kL, kD, t) is
    scored with p.beta_pdf under Beta(NTR['alpha'][j], NTR['beta'][j]) and
    multiplied by p.det_jac_mitoribo_from_tot(kL, kD, t) (presumably a
    change-of-variables Jacobian -- confirm in posteriors_jit). The product
    over all time points is returned.

    kL : first rate argument of ntr.lam_ribo; get_model_pred passes the
         T_mito_entry rate in this slot.
    kD : second rate argument of ntr.lam_ribo; get_model_pred passes the
         T_deg rate in this slot.
    NTR : typed dict with 'alpha' and 'beta' arrays, one entry per element of
          time_measured.

    NOTE: for the T_mito_entry fit, kD is marginalised over the T_deg credible
    interval, so the integration variable must be kD (second argument) and kL
    the free rate.
    """
    density = 1.0
    n_times = time_measured.shape[0]
    for j in range(n_times):
        t_j = time_measured[j]
        ntr_pred = ntr.lam_ribo(kL, kD, t_j)
        a_j = NTR['alpha'][j]
        b_j = NTR['beta'][j]
        like_j = p.beta_pdf(ntr_pred, a_j, b_j)
        # Keep the left-to-right multiplication order of the product.
        density = density * like_j * p.det_jac_mitoribo_from_tot(kL, kD, t_j)
    return density

def get_conditional_dist(post, NTR, ci, nvar):
    '''
    Return the marginal posterior of the rate currently being fitted.

    The multivariate posterior is integrated over the previously fitted rate,
    using that rate's credible interval as integration bounds.

    post : multivariate posterior called as post(k_new, k_prev, NTR): the rate
           being fitted comes first, the previously fitted rate second. This is
           the argument order of post_mitoribo(kL, kD, NTR), where kL is the
           T_mito_entry rate being fitted and kD the already-fitted T_deg rate.
    NTR : passed through unchanged to post.
    ci : list whose first element is the (lower, upper) credible interval of
         the previously fitted rate; only ci[0] is used.
    nvar : number of rate arguments of post; currently unused (only the
           two-rate case is handled).

    Returns marg_dist(x, NTR) -> float, the unnormalised marginal at k_new = x.
    '''
    def _joint_at(k_prev, x, NTR):
        # quad integrates over the FIRST argument of its callable, so put the
        # previously fitted rate (the integration variable) first here, while
        # post still receives the rate being fitted in its first slot.
        return post(x, k_prev, NTR)

    def marg_dist(x, NTR):
        return quad(_joint_at, ci[0][0], ci[0][1], args=(x, NTR), epsabs=1e-3)[0]

    return marg_dist

def get_post_norm_const(post, NTR):
    '''
    Return the normalisation constant of a univariate posterior and its error.

    The integral of post(k, NTR) over [k_bound_lo, k_bound_hi] is computed as a
    sum of quad integrals over consecutive decade-wide subdomains (presumably so
    quad resolves features across the many orders of magnitude of the domain).

    post : callable post(k, NTR) -> float.
    NTR : passed through unchanged to post.

    Returns (omega, omega_err): the integral and the summed quad error
    estimates. Both are nan when a subdomain integral is nan, when post raises
    IndexError/KeyError/ZeroDivisionError, or when the integral is exactly 0.
    '''
    # Number of decades between the bounds (8 for 1e-4 .. 1e4). Derived rather
    # than hard-coded so the integration domain always follows k_bound_lo/hi.
    n_decades = int(np.ceil(np.log10(k_bound_hi / k_bound_lo) - 1e-9))

    omega = 0
    omega_err = 0
    for i in range(n_decades):
        k_int_lo = k_bound_lo * 10**i
        # Clamp so a non-integer number of decades never overshoots k_bound_hi.
        k_int_hi = min(k_bound_lo * 10**(i + 1), k_bound_hi)
        try:
            k_int, k_int_err = quad(post, k_int_lo, k_int_hi, args=(NTR,), epsabs=1e-3)
        except (IndexError, KeyError, ZeroDivisionError):
            # k_int/k_int_err are not (re)assigned when quad raises, so report
            # the failing subdomain instead.
            logger.warning('norm_err2 %f %f' % (k_int_lo, k_int_hi))
            omega = np.nan
            omega_err = np.nan
            break
        logger.info('norm_const %f %f' % (k_int, k_int_err))
        if np.isnan(k_int):
            logger.warning('norm_err1: %f %f' % (k_int, k_int_err))
            omega = np.nan
            omega_err = np.nan
            break
        omega += k_int
        omega_err += k_int_err

    if omega == 0:
        logger.warning('norm_err3: norm_const is zero, set to nan and abort')
        omega = np.nan
        omega_err = np.nan
    return omega, omega_err
    
def get_post_mean(post, NTR, norm_const):
    '''
    Return the posterior mean of the rate and its integration error.

    Integrates k * post(k, NTR) / norm_const over [k_bound_lo, k_bound_hi] as a
    sum of quad integrals over consecutive decade-wide subdomains.

    post : callable post(k, NTR) -> float (unnormalised density).
    NTR : passed through unchanged to post.
    norm_const : normalisation constant of post (from get_post_norm_const).

    Returns (k_mean, k_mean_err); both nan when a subdomain integral is nan,
    when post raises IndexError/KeyError/ZeroDivisionError, or when the mean is
    exactly zero.
    '''
    def integrand(x, ntr_data):
        # Kept as plain Python: per the original note, numba cannot compile a
        # call to the function post.
        return x * post(x, ntr_data) / norm_const

    # Number of decades between the bounds (8 for 1e-4 .. 1e4). Derived rather
    # than hard-coded so the integration domain always matches the rate grid.
    n_decades = int(np.ceil(np.log10(k_bound_hi / k_bound_lo) - 1e-9))

    k_mean = 0
    k_mean_err = 0
    for i in range(n_decades):
        k_int_lo = k_bound_lo * 10**i
        # Clamp so a non-integer number of decades never overshoots k_bound_hi.
        k_int_hi = min(k_bound_lo * 10**(i + 1), k_bound_hi)
        try:
            k_int, k_int_err = quad(integrand, k_int_lo, k_int_hi, args=(NTR,), epsabs=1e-3)
        except (IndexError, KeyError, ZeroDivisionError):
            logger.warning('mean_err2 %f %f' % (k_int_lo, k_int_hi))
            k_mean = np.nan
            k_mean_err = np.nan
            break
        logger.info('mean %f %f' % (k_int, k_int_err))
        if np.isnan(k_int):
            logger.warning('mean_err1 %f %f' % (k_int_lo, k_int_hi))
            k_mean = np.nan
            k_mean_err = np.nan
            break
        k_mean += k_int
        k_mean_err += k_int_err

    if k_mean == 0:
        logger.warning('mean_err3: mean is zero, set to nan')
        k_mean = np.nan
        k_mean_err = np.nan
    return k_mean, k_mean_err

def get_post_ci(k_domain, post_grid):
    '''
    Return the 95% credible interval (ci_lo, ci_hi) of a gridded posterior.

    k_domain : increasing grid of rate values.
    post_grid : posterior density evaluated on k_domain.

    The CDF is built by cumulative trapezoidal integration and renormalised to
    end at 1. ci_lo is the largest grid point with CDF <= alpha/2 (else the
    first grid point); ci_hi is the smallest grid point with CDF >= 1 - alpha/2
    (else the last grid point).
    '''
    alpha = 0.05
    lower_tail = alpha / 2
    upper_tail = 1 - alpha / 2

    cdf = cumulative_trapezoid(post_grid, k_domain, initial=0)
    omega = max(cdf)
    logger.info('omega %f' % omega)
    # The density should already be normalised; renormalise the CDF anyway.
    cdf = cdf / omega

    below = k_domain[cdf <= lower_tail]
    ci_lo = max(below) if below.size else k_domain[0]

    above = k_domain[cdf >= upper_tail]
    ci_hi = min(above) if above.size else k_domain[-1]
    return ci_lo, ci_hi

def get_estimates_from_post(post, NTR, N_var, ci_k_prev=None):
    '''
    Summarise a rate posterior by its mean, MAP and 95% credible interval.

    post : posterior density (function). If N_var > 1 it is multivariate and is
           first replaced, via get_conditional_dist, by the marginal of the rate
           being fitted, integrating over the credible interval in ci_k_prev.
    NTR : data passed through to post.
    N_var : number of rate arguments of post.
    ci_k_prev : credible intervals of the previously fitted rates (list of
                (lower, upper) pairs); only needed when N_var > 1.

    Returns np.ndarray [mean, MAP, CI lower, CI upper] of the rate; all nan
    when the normalisation constant is not finite.
    '''
    if N_var > 1:
        post = get_conditional_dist(post, NTR, ci_k_prev, N_var)
    logger.info('g_e_f_p0')

    # Normalisation constant of the (marginal) posterior.
    norm_const, err = get_post_norm_const(post, NTR)
    logger.info('norm_constant %f %f' % (norm_const, err))
    logger.info('g_e_f_p1')

    if not np.isfinite(norm_const):
        return np.asarray([np.nan] * 4)

    # Evaluate the normalised posterior on a log-spaced grid over the prior
    # domain; the grid provides the MAP and the credible interval.
    k_domain = np.geomspace(k_bound_lo, k_bound_hi, num=N_grid)
    post_grid = np.array([post(k, NTR) for k in k_domain], dtype=np.float64) / norm_const
    logger.info('g_e_f_p2')

    mean_k, mean_k_err = get_post_mean(post, NTR, norm_const)
    logger.info('g_e_f_p3')
    map_k = k_domain[np.argmax(post_grid)]
    logger.info('g_e_f_p4')

    ci_lo, ci_hi = get_post_ci(k_domain, post_grid)
    logger.info('g_e_f_p5')

    return np.asarray([mean_k, map_k, ci_lo, ci_hi])

@jit(nb.float64[:](types.unicode_type, types.unicode_type,
                types.DictType(types.unicode_type, nb.float64[:]), types.unicode_type), nopython=True) 
def get_model_pred(rr, frm, k, e_type):
    '''
    Predict the NTR at every measured time point from fitted rates.

    rr : replicate id; prefix of the keys of k.
    frm : model name, 'tot_fit' or 'mitoribo_fit'.
    k : typed dict of rate estimates keyed rr + timescale name, each ordered as
        returned by get_estimates_from_post ([mean, MAP, CI lower, CI upper]).
    e_type : ' Mean' selects index 0 of each estimate; any other value (' MAP')
             selects index 1.

    Returns one predicted NTR per entry of time_measured; all nan for an
    unrecognised model name.
    '''
    N_time = time_measured.shape[0]
    # np.full rather than np.empty: an unknown frm must not hand back
    # uninitialised memory as if it were a prediction.
    NTR_model = np.full(N_time, np.nan)
    if e_type == ' Mean':
        idx = 0
    else:  # ' MAP'
        idx = 1

    if frm == 'tot_fit':
        for i in range(N_time):
            NTR_model[i] = ntr.lam_total_one_step(k[rr+'T_deg'][idx],
                                                  time_measured[i])
    elif frm == 'mitoribo_fit':
        for i in range(N_time):
            NTR_model[i] = ntr.lam_ribo(k[rr+'T_mito_entry'][idx],
                                        k[rr+'T_deg'][idx],
                                        time_measured[i])
    return NTR_model


@jit(nb.float64(nb.float64[:], nb.float64[:]), nopython=True) 
def get_chi2_jit(ntr_meas, ntr_model):
    """Pearson-style chi-square between measured and model NTR values.

    Returns sum_j (ntr_model[j] - ntr_meas[j])**2 / (ntr_model[j] + eps) over
    all entries of ntr_meas; eps guards against division by a zero model NTR.
    """
    eps = 1e-16
    total = 0.0
    n_points = ntr_meas.shape[0]
    for j in range(n_points):
        resid = ntr_model[j] - ntr_meas[j]
        total = total + resid**2 / (ntr_model[j] + eps)
    return total

#### Write the genes that have sufficient data to estimate timescales to file

In [11]:
# Collect, per replicate, every gene present in any fitted fraction table, then
# take the union over all replicates in one step (rebinding a single variable
# from dict to set inside the replicate loop broke for three or more replicates).
genes_per_rep = dict()
for r in reps:
    genes_per_rep[r] = set()
    for fr in fracs:
        if fr not in ['tot', 'mitoribo']:
            continue
        if r + fr not in GS:
            # The loading step skips missing input files, so a table may be absent.
            logger.warning('%s %s: no input table, skipped' % (r, fr))
            continue
        genes_per_rep[r] = genes_per_rep[r].union(GS[r+fr]['Gene'])
        logger.info('%s %s %s' % (r, fr, len(genes_per_rep[r])))
genes_w_rates = sorted(set().union(*genes_per_rep.values()))
logger.info('total number of genes: union of (if > 1) biological replicates: %d' % len(genes_w_rates))

filename = 'genes_w_rates.csv'
genes_df = pd.Series(genes_w_rates)
# genes_df.to_csv(os.path.join(outpath, filename), sep='\t',index=False, header=False)

In [9]:
logger.info('Run model parameter fitting for all genes')

# Output table: one list per column, filled gene by gene in the fitting loop.
fits = {'Gene': [], 'Symbol': []}
# Timescale summaries, one column per <rep>.<timescale><statistic>.
for r in reps:
    for timescale in Timescales:
        for suf in OUT_TYPES:
            fits[f'{r}.{timescale}{suf}'] = []
# Goodness of fit of the Mean and MAP predictions, one column per
# <rep>.<model><statistic>.chi2.
for r in reps:
    for fr in fracs:
        for frm in fracs_model[fr]:
            for estimate_type in OUT_TYPES[:2]:
                fits[f'{r}.{frm}{estimate_type}.chi2'] = []

for ensid in genes_w_rates:
    logger.info(ensid)
    symbol = ''
    # Per-gene rate estimates keyed <rep><timescale>; each value holds the four
    # summaries returned by get_estimates_from_post and starts as all nan.
    k_fit = Dict.empty(key_type=types.unicode_type,
                       value_type=types.float64[:])
    for r in reps:
        for timescale in Timescales:
            k_fit[r + timescale] = np.asarray([np.nan for out in OUT_TYPES])

    for r in reps:
        for fr in fracs:
            try:
                # Numba typed dict of the Beta parameters per measured time
                # point; key and value types must be declared explicitly.
                AB = Dict.empty(key_type=types.unicode_type,
                                value_type=types.float64[:],)
                ABORT = False

                gene_idx = GS[r+fr][GS[r+fr]['Gene']==ensid].index[0]
                symbol = GS[r+fr]['Symbol'][gene_idx]
                for ab in CI_para:
                    gs_ab_times = [r+'_'+str(t)+'m '+ab for t in time_mins[1:]]
                    AB[ab] = np.asarray(GS[r+fr].loc[gene_idx, gs_ab_times], dtype='float64')
                    if not np.all(np.isfinite(AB[ab])):
                        logger.warning('%s %s: abort because AB contains nonfinite data' % (r,fr))
                        ABORT = True
                        break

                if not ABORT:
                    for frm in fracs_model[fr]:
                        logger.info('fitting rate for %s %s %s' % (r, fr, frm))
                        try:
                            if frm == 'tot_fit':
                                k_fit[r+'T_deg'] = get_estimates_from_post(post_mitototal, AB, 1)
                            elif frm == 'mitoribo_fit' and \
                                    sum(np.isnan(k_fit[r + 'T_deg'][2:])) == 0:
                                # Condition on the T_deg credible interval fitted
                                # on the total fraction.
                                k_fit[r+'T_mito_entry'] = get_estimates_from_post(
                                    post_mitoribo, AB, 2, [k_fit[r+'T_deg'][2:4]])
                        except (IndexError, KeyError, ZeroDivisionError):
                            logger.warning('except error, so nan rates')

            except (IndexError, KeyError, ZeroDivisionError):
                logger.warning('%s %s: no data, so nan rates' % (r, fr))

    # Append timescales (1/rate) for EVERY replicate (previously only the last
    # loop value of r was used, leaving other replicates' columns short).
    # Inverting swaps the bounds: 1/(CI lower rate) is the upper timescale
    # bound, hence ' 0.975 quantile' precedes ' 0.025 quantile' in OUT_TYPES.
    # The ' Mean' timescale is 1/(posterior mean rate), not the posterior mean
    # of the timescale.
    for r in reps:
        for timescale in Timescales:
            for i, out in enumerate(OUT_TYPES):
                fits[r+'.'+timescale+out].append((k_fit[r+timescale][i])**(-1))

    # Goodness of fit of the Mean and MAP predictions against the measured NTR.
    logger.info('get model fit chi2')
    for r in reps:
        for fr in fracs:
            for estimate_type in OUT_TYPES[:2]:
                gs_times = [r+'_'+str(t)+'m'+estimate_type for t in time_mins[1:]]
                try:
                    gene_idx = GS[r+fr][GS[r+fr]['Gene']==ensid].index[0]
                    NTR = np.asarray(GS[r+fr].loc[gene_idx, gs_times], dtype='float64')
                    for frm in fracs_model[fr]:
                        NTR_model = get_model_pred(r, frm, k_fit, estimate_type)
                        chi2 = get_chi2_jit(NTR, NTR_model)
                        fits[r+'.'+frm+estimate_type+'.chi2'].append(chi2)
                except (IndexError, KeyError):
                    for frm in fracs_model[fr]:
                        fits[r+'.'+frm+estimate_type+'.chi2'].append(np.nan)

    fits['Gene'].append(ensid)
    fits['Symbol'].append(symbol)

filename = 'Bayes20220206_mitoribo_fit_timescales.tsv'
logger.info('Write results to file %s' % filename)
# One row per gene, one column per fits key (insertion order preserved).
fits_df = pd.DataFrame.from_dict(fits)
# NOTE: the write below is commented out, so despite the log message above no
# file is produced until it is re-enabled.
# fits_df.to_csv(os.path.join(outpath, filename), sep='\t',index=False)

logger.info('end')

In [10]:
# fits_df 

# UNFINISHED BELOW: visualization for individual genes

### Select gene

In [8]:
# Gene symbol to inspect. Alternatives listed by the author: Nfat5, Lsr,
# Stard3, Txnrd3, Bysl, Alyref, Narf, Apoh, Tbx2, Gapdh, Myc, Gnai3, Klf4,
# Tfam, Vps50, Uba3.
gene = 'Hnrnpd'

In [83]:
logger.info(gene)

# Fit the rates of a single gene (selected by symbol) for inspection.
# Result: k_fit[<rep><model>] -> [mean, MAP, CI lower, CI upper] rates from
# get_estimates_from_post, or a list of five nan values when the gene has no
# usable data (the fifth slot is a placeholder for a posterior grid, which
# get_estimates_from_post does not currently return).
k_fit = dict()
for r in reps:
    for fr in fracs:
        logger.info('%s %s' % (r, fr))
        try:
            # Numba typed dict of the Beta parameters per measured time point;
            # key and value types must be declared explicitly.
            AB = Dict.empty(key_type=types.unicode_type,
                            value_type=types.float64[:],)
            ABORT = False

            gene_idx = GS[r+fr][GS[r+fr]['Symbol']==gene].index[0]
            for ab in CI_para:
                # Same column naming as the genome-wide fit.
                gs_times = [r+'_'+str(t)+'m '+ab for t in time_mins[1:]]
                AB[ab] = np.asarray(GS[r+fr].loc[gene_idx, gs_times], dtype='float64')
                logger.info('AB %s %s' % (ab, AB[ab]))
                if not np.all(np.isfinite(AB[ab])):
                    logger.warning('Abort because AB contains nonfinite data')
                    ABORT = True
                    # Stop reading parameters for this fraction only; the loop
                    # over fractions continues.
                    break

            if ABORT:
                for frm in fracs_model[fr]:
                    k_fit[r+frm] = [np.nan for i in range(5)]
                continue

            for frm in fracs_model[fr]:
                logger.info('%s %s %s' % (r, fr, frm))
                if frm == 'tot_fit':
                    k_fit[r+frm] = get_estimates_from_post(post_mitototal, AB, 1)
                elif frm == 'mitoribo_fit':
                    # Condition on the credible interval of the rate fitted on
                    # the total fraction; skip if that fit is unavailable.
                    tot_est = k_fit.get(r+'tot_fit')
                    if tot_est is None or np.any(np.isnan(tot_est[2:4])):
                        logger.warning('%s %s: no finite tot_fit CI, so nan rates' % (r, fr))
                        k_fit[r+frm] = [np.nan for i in range(5)]
                        continue
                    k_fit[r+frm] = get_estimates_from_post(post_mitoribo, AB, 2,
                                                           [tot_est[2:4]])
                else:
                    continue

                logger.info('time_fit %s %s %s \n' % (r, frm, [1/k for k in k_fit[r+frm][:4]]))
        except (IndexError, KeyError, ZeroDivisionError):
            logger.warning('no data, so nan rates')
            for frm in fracs_model[fr]:
                k_fit[r+frm] = [np.nan for i in range(5)]
logger.info('end')

### Visualize posterior

In [84]:
# Plot the posterior of each fitted rate of the selected gene, marking the
# credible-interval bounds (red, k_out[2] and k_out[3]), the mean (green,
# k_out[0]) and the MAP (blue, k_out[1]).
# NOTE(review): this cell is unfinished and does not run as written:
#   - red_r is not defined anywhere in this notebook (NameError);
#   - get_estimates_from_post returns only [mean, MAP, CI lo, CI hi], so
#     k_out[4] (intended posterior grid) exists only for the 5-element nan
#     placeholders; for a successful fit, k_out[4] raises IndexError.
for fr in fracs:
    for frm in fracs_model[fr]:
        for r in reps[:2]:#[::-1]: 
            if red_r[r]+frm in k_fit.keys():
                k_out = k_fit[red_r[r]+frm]
                
                logger.info('time_fit %s %s %s \n' % (red_r[r], frm, [1/k for k in k_fit[red_r[r]+frm][:4]]))
                
                # A nan placeholder in slot 4 (a Python float) means "no
                # posterior grid"; only array-valued grids are plotted.
                if not type(k_out[4])==float:#nan
                    post = k_out[4]
                    k_domain = np.geomspace(k_bound_lo, k_bound_hi, num=N_grid)

                    sns.set(style="whitegrid")
                    fig, ax = plt.subplots(figsize=(4,4))#inches
                    g = sns.lineplot(x=k_domain, y=post, color='b',
                                  #sizes=[5],# linewidth=2,alpha=1,linestyle='-',
                                  ax=ax,legend=False)
                    plt.axvline(x=k_out[2], color='r', linestyle='-', linewidth=2)
                    plt.axvline(x=k_out[0], color='g', linestyle='-', linewidth=2)
                    plt.axvline(x=k_out[1], color='b', linestyle='-', linewidth=2)
                    plt.axvline(x=k_out[3], color='r', linestyle='-', linewidth=2)
                    plt.xlabel('k from ' + frm)
                    plt.ylabel('posterior density function')
                    plt.title(red_r[r]+ ' '+gene)
                    g.set(xscale="log") 
                    # Zoom the x-axis around the credible interval.
                    plt.xlim([k_out[2]/2,k_out[3]*1.5])




In [33]:
!pip freeze

asteval==0.9.23
attrs==19.3.0
backcall==0.1.0
bleach==3.1.4
certifi==2021.5.30
charset-normalizer==2.0.1
cycler==0.10.0
decorator==4.4.2
defusedxml==0.6.0
docopt==0.6.2
entrypoints==0.3
future==0.18.2
goatools==1.1.6
gtfparse==1.2.1
idna==3.2
importlib-metadata==1.5.2
ipykernel==5.2.0
ipython==7.13.0
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.16.0
Jinja2==2.11.1
joblib==1.1.0
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==6.1.2
jupyter-console==6.1.0
jupyter-core==4.6.3
kiwisolver==1.1.0
llvmlite==0.36.0
lmfit==1.0.2
MarkupSafe==1.1.1
matplotlib==3.2.1
mistune==0.8.4
mpmath==1.2.1
nbconvert==5.6.1
nbformat==5.0.4
networkx==2.4
notebook==6.0.3
numba==0.53.1
numba-scipy==0.3.0
numpy==1.16.5
pandas==1.0.3
pandocfilters==1.4.2
parso==0.6.2
patsy==0.5.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==7.0.0
prometheus-client==0.7.1
prompt-toolkit==3.0.4
ptyprocess==0.6.0
pydot==1.4.2
Pygments==2.6.1
pyparsing==2.4.6
pyrsistent==0.16.0
python-dateutil==2.8.1
pytz==2019.3
pyzmq==19.0.0
qtco