In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import glob
from scipy.special import exp1,factorial
from scipy.stats import binom,nbinom
from scipy.special import gammaln
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")
from scipy.interpolate import griddata



In [2]:
### Functions for theory

def get_lc(sigma,s):
    return np.sqrt(sigma**2/s)

def get_lambda_theory(w,sigma,s):
    lc = get_lc(sigma,s)
    term = (w / lc) ** 2
    if term <= 800:
        prod_term = np.exp(term) * exp1(term)
    else:
        prod_term = sum((factorial(k) / (-term)**k for k in range(7))) / term
    return (4*np.pi)/prod_term

def get_EP_theory(mu, s):
    return mu / s

def get_EPsquared_theory(mu, s, rho, sigma, w):
    lcs = get_lc(sigma, s)**2
    term = (w / np.sqrt(lcs)) ** 2
    if term <= 800:
        prod_term = np.exp(term) * exp1(term)
    else:
        prod_term = sum((factorial(k) / (-term)**k for k in range(7))) / term
    return (mu / (s ** 2 * rho * 4 * np.pi * lcs)) * prod_term + mu ** 2 / s ** 2

def get_sfs_theory(x,n,mu,s,rho,sigma,w):
    mean = get_EP_theory(mu,s)
    var = get_EPsquared_theory(mu,s,rho,sigma,w) - mean**2
    alpha = mean**2/var
    beta = mean/var
    return nbinom.pmf(x,alpha,beta/(beta+n))

def get_sfs_theory_unif(x,n,mu,s,N):
    mean = mu/s
    var = mu/(s*s*N)
    alpha = mean**2/var
    beta = mean/var
    return nbinom.pmf(x,alpha,beta/(beta+n))

### Comparison between theory and data

In [3]:
### Scale factors for y axis - from empiriaclplots_v20250217

len_syn = 1308.0216666666665
len_mis = 2616.043333333333
len_lof = 167.616

prop_kept_syn = 32320/247378
prop_kept_mis = 32320/505963
prop_kept_lof = 1

In [4]:
### Function 

def get_obs_dist(width,vartype,scale_factor,centers): # return counts: scale factor should be equal to prop_kept
    obs_dist_list = []
    if width is not None: # not uniform
        for center in centers: # average over centers
            sfs = pd.read_csv(f'../results/sfs_freq/chr1_{vartype}_{center}geo{width}_nSIR10000_nSIRreps10.SIRfreq.sfs', sep='\t')
            sfs_grouped = sfs.groupby('COUNT')['SFS_COUNTS'].agg(['mean', 'std']) # average over replicates
            obs_dist_temp = sfs_grouped['mean'] / scale_factor # scale to adjust for downsampling
            obs_dist_list.append(obs_dist_temp)
        obs_dist = np.mean(obs_dist_list,axis=0)
    else: # uniform
        sfs = pd.read_csv(f'../results/sfs_freq/chr1_{vartype}_uniformgeo_nSIR10000_nSIRreps10.SIRfreq.sfs', sep=' ')
        sfs_grouped = sfs.groupby('COUNT')['SFS_COUNTS'].agg(['mean', 'std']) # average over replicates
        obs_dist = sfs_grouped['mean'] / scale_factor # scale to adjust for downsampling
    return obs_dist

### Function to compare observed (data) to expected (theory) sfs and return log likelihood

def comp_logl(nb_dist,width,vartype,scale_factor,centers=['centerE9N9','centerE16N4','centerE6N4'],min_x=1,max_x=10):
    obs_dist = get_obs_dist(width,vartype,scale_factor,centers)
    obs_dist = obs_dist[min_x:max_x+1]
    # log likelihood calculation log L = sum(OlogE-E-logO!)
    logl = np.sum(obs_dist * np.log(nb_dist) - nb_dist - gammaln(obs_dist + 1))  #log gamma is numerically stable implementation of log(O!)
    return logl

### Function for grid search 
def grid_search_multiw(n,sigma_list,rho_list,s_list,w_list,mu_list,widths,vartype,scale_factor,L_scale,max_x=10,min_x=1):
    assert len(w_list) == len(widths), "w_list and width_list must be the same length"
    logl_results = []
    for sigma in sigma_list:
        for rho in rho_list:
            for s in s_list:
                for mu in mu_list:
                    logl_sum = 0
                    for i,w in enumerate(w_list):
                        # nb dist needs to be counts - theory returns bp - convert to kb with *1000 and then to length of appropriate region with L_scale (in kb)
                        nb_dist = np.array([get_sfs_theory(y, n, mu, s, rho, sigma, w)*L_scale*1000 for y in np.arange(min_x, max_x+1)])
                        logl_sum += comp_logl(nb_dist,widths[i],vartype,scale_factor,min_x=min_x,max_x=max_x)
                    logl_results.append((sigma,rho,s,'multi',mu,logl_sum))
    logl_results.sort(key=lambda x:x[5],reverse=True)
    return logl_results

### Function to print grid search results

def print_res(results,vt,max_rank=20):
    print(f"Ranking of parameter combinations by best fit (highest logL), {vt} variants:")
    for rank, (sigma, rho, s, w, mu, logl) in enumerate(results, 1):
        if rank<=max_rank:
            print(f"{rank}. sigma={sigma}, rho={rho}, s={s:.1e}, w={w}, mu={mu:.2e}, logl={logl:.5f}")

### Plotting function -sfs

def plot_data_vs_theory(sigma, rho, s, mu=1.25e-8, n=10000, centers=['centerE9N9','centerE16N4','centerE6N4'],
                        widths_data=['50000','100000','150000',None], vartypes=['synonymous','missense','lof'],
                        scale_factors=[prop_kept_syn, prop_kept_mis, 1], maxval=50):
    colors = ['steelblue', 'orchid', 'darkorange']
    widths_theory = [50, 100, 150, 500]
    Lscale = [len_syn, len_mis, len_lof]
    
    fig, ax = plt.subplots(1, 4, figsize=(20, 5))
    x = np.logspace(0, 2)
    y = x**-1

    for i in [0, 1, 2, 3]:
        ax[i].loglog(x, y, color='lightgray', linestyle='--')
        ax[i].loglog(x, y * 10, color='lightgray', linestyle='--')
        ax[i].loglog(x, y * 100, color='lightgray', linestyle='--')
        ax[i].loglog(x, y / 100, color='lightgray', linestyle='--')
        ax[i].loglog(x, y / 10, color='lightgray', linestyle='--')

    for j, w in enumerate(widths_theory):
        nb_dist = [get_sfs_theory(y, n, mu, s, rho, sigma, w) * 1000 for y in np.arange(0, maxval)] # plot per kb
        ax[j].loglog(np.arange(0, maxval), nb_dist, marker=None, linestyle='-', linewidth=1, alpha=1, color='black', label=f'theory')

    for j, wid in enumerate(widths_data):
        for i, vt in enumerate(vartypes):
            sfs = get_obs_dist(wid, vt, scale_factors[i], centers=centers)
            sfs_kb = [x / Lscale[i] for x in sfs] # plot per kb
            ax[j].loglog(np.arange(0, maxval), sfs_kb[:maxval], marker='x', color=colors[i], linestyle='', label=vartypes[i])

    for i in [0, 1, 2, 3]:
        ax[i].set_ylim(1e-3, 1e2)
        ax[i].legend()

    ax[0].set_title('w=50km')
    ax[1].set_title('w=100km')
    ax[2].set_title('w=150km')
    ax[3].set_title('uniform')

    plt.show()

## Plotting function - heatmap

def plot_log_likelihood(res_df):
    res_df['rho'] = res_df['rho'].round(2)
    res_df['sigma'] = res_df['sigma'].round(2)
    
    filtered_data = res_df.loc[res_df.groupby(['sigma', 'rho'])['logl'].idxmax()]
    heatmap_data = filtered_data.pivot(index='rho', columns='sigma', values='logl')
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(heatmap_data, annot=False, cmap="plasma", fmt=".2f", annot_kws={"size": 6})
    plt.xlabel("Sigma")
    plt.ylabel("Rho")
    plt.title('log likelihood - Poisson')
    plt.show()



def plot_log_likelihood_contour(res_df):
    res_df['rho'] = res_df['rho'].round(2)
    res_df['sigma'] = res_df['sigma'].round(2)
    
    filtered_data = res_df.loc[res_df.groupby(['sigma', 'rho'])['logl'].idxmax()]
    
    # Prepare the grid for contour plot
    rho_vals = np.unique(filtered_data['rho'])
    sigma_vals = np.unique(filtered_data['sigma'])
    
    # Create a meshgrid for contour plot
    Rho, Sigma = np.meshgrid(rho_vals, sigma_vals)
    LogL = filtered_data.pivot(index='rho', columns='sigma', values='logl').values
    
    # Define contour levels
    logl_min = LogL.min()
    logl_max = LogL.max()
    contour_levels = np.linspace(logl_min, logl_max, 20)
    
    plt.figure(figsize=(8, 6))
    cp = plt.contourf(Sigma, Rho, LogL, cmap="plasma", levels=contour_levels)
    plt.colorbar(cp, label='Log Likelihood')
    plt.xlabel("Sigma")
    plt.ylabel("Rho")
    plt.title('Log Likelihood Contour Plot - Poisson')
    
    # Flip the y-axis (Rho) order
    plt.gca().invert_yaxis()
    
    plt.show()


In [8]:
# # test case 
# widths = ['50000','100000','150000']
# mu_list = [1.25e-8]
# w_list = [50,100,150]
# vartype = 'missense'
# sigma_list = np.linspace(1,500,1)#(1,30,30)
# rho_list = np.linspace(0.0001,7.5,1)
# s_list = np.logspace(-4,-1,1)
# scale_factor = prop_kept_mis
# L_scale = len_mis
# n=20000 # 2*num indiv

# res3_minx1 = grid_search_multiw(n,sigma_list,rho_list,s_list,w_list,mu_list,widths,vartype,scale_factor,L_scale,min_x=1,max_x=10)
# res3_minx1 = pd.DataFrame(res3_minx1,columns=['sigma','rho','s','w','mu','logl'])

In [9]:
# res3_minx1

In [None]:
# actual case

In [25]:
widths = ['50000','100000','150000']
mu_list = [1.25e-8]#,2.5e-8]
w_list = [50,100,150]
vartype = 'missense'
sigma_list = np.linspace(1,500,50)
rho_list = np.linspace(0.0001,7.5,50)
s_list = np.logspace(-4,-1,40)
scale_factor = prop_kept_mis
L_scale = len_mis
n=20000 # 2*num indiv

In [12]:
res3_minx1 = grid_search_multiw(n,sigma_list,rho_list,s_list,w_list,mu_list,widths,vartype,scale_factor,L_scale,min_x=1,max_x=10)
res3_minx1 = pd.DataFrame(res3_minx1,columns=['sigma','rho','s','w','mu','logl'])
res3_minx1.to_csv("res_w50_w100_w150_minx1_sigma500_rho7.5_mu1.25_sfs_freq.csv", index=False)

In [26]:
res3_minx2 = grid_search_multiw(n,sigma_list,rho_list,s_list,w_list,mu_list,widths,vartype,scale_factor,L_scale,min_x=2,max_x=10)
res3_minx2 = pd.DataFrame(res3_minx2,columns=['sigma','rho','s','w','mu','logl'])
res3_minx2.to_csv("res_w50_w100_w150_minx2_sigma500_rho7.5_mu1.25_sfs_freq.csv", index=False)

In [27]:
res3_minx3 = grid_search_multiw(n,sigma_list,rho_list,s_list,w_list,mu_list,widths,vartype,scale_factor,L_scale,min_x=3,max_x=10)
res3_minx3 = pd.DataFrame(res3_minx3,columns=['sigma','rho','s','w','mu','logl'])
res3_minx3.to_csv("res_w50_w100_w150_minx3_sigma500_rho7.5_mu1.25_sfs_freq.csv", index=False)

In [14]:
res3_minx1

Unnamed: 0,sigma,rho,s,w,mu,logl
0,21.367347,1.530692,0.007017,multi,1.250000e-08,-1.153108e+03
1,31.551020,0.918455,0.007017,multi,1.250000e-08,-1.403923e+03
2,72.285714,0.306218,0.007017,multi,1.250000e-08,-1.952235e+03
3,133.387755,0.153159,0.008377,multi,1.250000e-08,-2.813136e+03
4,500.000000,0.000100,0.000100,multi,1.250000e-08,-2.553798e+04
...,...,...,...,...,...,...
99995,489.816327,7.346941,0.000100,multi,1.250000e-08,-6.668903e+06
99996,500.000000,7.193882,0.000100,multi,1.250000e-08,-6.669055e+06
99997,489.816327,7.500000,0.000100,multi,1.250000e-08,-6.669090e+06
99998,500.000000,7.346941,0.000100,multi,1.250000e-08,-6.669243e+06


### LoFs

In [28]:
# sigma_list = np.linspace(1,30,30)
# rho_list = np.linspace(0.1,10,30)
s_list = np.logspace(-4,-1,40)
mu_list = [1.25e-8]#,2.5e-8
widths = ['50000','100000','150000']
w_list = [50,100,150]
vartype = 'lof'
scale_factor = prop_kept_lof
L_scale = len_lof
n=20000 # 2*num indiv

In [19]:
lof_min1 = grid_search_multiw(n,[21.367347],[1.530692],s_list,w_list,mu_list,widths,vartype,scale_factor,L_scale,min_x=1,max_x=10)
lof_min1_df = pd.DataFrame(lof_min1,columns=['sigma','rho','s','w','mu','logl'])

In [20]:
lof_min1_df

Unnamed: 0,sigma,rho,s,w,mu,logl
0,21.367347,1.530692,0.01,multi,1.25e-08,-165.986218
1,21.367347,1.530692,0.008377,multi,1.25e-08,-214.377468
2,21.367347,1.530692,0.011938,multi,1.25e-08,-222.816695
3,21.367347,1.530692,0.014251,multi,1.25e-08,-374.396145
4,21.367347,1.530692,0.007017,multi,1.25e-08,-380.781344
5,21.367347,1.530692,0.017013,multi,1.25e-08,-612.30147
6,21.367347,1.530692,0.005878,multi,1.25e-08,-680.601948
7,21.367347,1.530692,0.020309,multi,1.25e-08,-929.911727
8,21.367347,1.530692,0.004924,multi,1.25e-08,-1132.183254
9,21.367347,1.530692,0.024245,multi,1.25e-08,-1322.176017


In [21]:
lof_min1_df.to_csv('lof_min1_sfs_freq.csv',index=False)

In [29]:
lof_min2 = grid_search_multiw(n,[21.367347],[1.224573],s_list,w_list,mu_list,widths,vartype,scale_factor,L_scale,min_x=2,max_x=10)
lof_min2_df = pd.DataFrame(lof_min2,columns=['sigma','rho','s','w','mu','logl'])
lof_min2_df.to_csv('lof_min2_sfs_freq.csv',index=False)

In [30]:
lof_min3 = grid_search_multiw(n,[51.918367],[0.306218],s_list,w_list,mu_list,widths,vartype,scale_factor,L_scale,min_x=3,max_x=10)
lof_min3_df = pd.DataFrame(lof_min3,columns=['sigma','rho','s','w','mu','logl'])
lof_min3_df.to_csv('lof_min3_sfs_freq.csv',index=False)