In [1]:
import numpy as np
import pandas as pd
from scipy import sparse,stats
import pickle
import statsmodels.api as sm
import warnings
import anndata

In [2]:
#directory the single-cluster input .h5ad file is read from
basepath = 'data/tasic/'
#directory that receives the ZINB fit results, the grid-search table and the pickled parameter lists
basepath_simulations = 'data/tasic/simulations/'
#project-local helpers: read-count simulation and broken-zeta sampling
from readcount_tools import simulate_readcounts, broken_zeta

In [3]:
def zinb_fit(counts,seed=42,verbose=False):
    """Fit an intercept-only zero-inflated negative binomial model to one count vector.

    Parameters
    ----------
    counts : array-like
        Observed counts. They are used as the endogenous variable; an all-ones
        array of the same shape serves as the only regressor.
    seed : int
        Passed to ``np.random.seed`` before fitting.
    verbose : bool
        Forwarded as ``disp`` to ``fit_regularized``.

    Returns
    -------
    tuple
        ``(mu_fit, theta_fit, inflation_fit, fit_result)`` where ``mu_fit`` is
        ``exp`` of the second fitted parameter, ``theta_fit`` the reciprocal of
        the third, ``inflation_fit`` is ``exp`` of the first, and ``fit_result``
        is the object returned by ``fit_regularized``.
    """
    np.random.seed(seed)

    #fit with a constant regressor only
    intercept = np.ones_like(counts)
    model = sm.ZeroInflatedNegativeBinomialP(counts,intercept)
    fit_result = model.fit_regularized(disp=verbose)

    #recover parameters; they are unpacked in the order (inflation, mean, dispersion)
    #NOTE(review): statsmodels' default inflation link for this model is presumably
    #'logit'; if so, exp() of the inflation parameter is an odds, not a probability.
    #Confirm before interpreting inflation_fit as a probability.
    raw_inflation, raw_mean, dispersion = fit_result.params

    return np.exp(raw_mean), 1/dispersion, np.exp(raw_inflation), fit_result

def init_ad_for_fitting(ad,warning2text):
    """Add the per-gene result and bookkeeping columns to ``ad.var`` in place.

    Parameters
    ----------
    ad : object with a DataFrame-like ``.var`` and a ``.shape`` tuple
        One row of ``.var`` per gene; ``ad.shape[1]`` is the number of genes.
    warning2text : dict
        Maps a warning type name to its message text. Only the keys are used
        here: each key ``k`` yields a boolean column ``f'{k}_idx'``.

    Returns
    -------
    None
        ``ad.var`` is modified in place.
    """
    n_genes = ad.shape[1]

    #add fields to log results (NaN until a fit writes a value)
    for column in ('mus_fit','thetas_fit','inflations_fit'):
        ad.var[column] = np.full(n_genes, np.nan)

    #add fields to log warnings per gene: one flag per known type plus two catch-alls
    flag_columns = [f'{warning_type}_idx' for warning_type in warning2text]
    flag_columns += ['no_warning_idx','other_warning_idx']
    for column in flag_columns:
        ad.var[column] = np.zeros(n_genes, dtype=bool)
    ad.var['cought_warnings'] = ['[]'] * n_genes
    ad.var['cought_errors'] = [''] * n_genes

    #add fields to log other outcomes
    ad.var['converged'] = np.zeros(n_genes, dtype=bool)
    ad.var['n_iters'] = np.full(n_genes, np.nan)
    ad.var['warnflags_statsmodels'] = [''] * n_genes

def print_warnings(ad,i,outcome_type,text,print_msg=True):
    """Print a one-line report about a warning or error raised while fitting gene ``i``.

    Parameters
    ----------
    ad : object with a DataFrame-like ``.var`` and a ``.shape`` tuple
        ``.var`` must provide the columns ``'genes'`` and
        ``'gene_mean_withinCluster'``.
    i : int
        Position (not label) of the gene within ``ad.var``.
    outcome_type : str
        Short description of what was raised; inserted into the message.
    text : object
        Message text (or list of texts); inserted into the message.
    print_msg : bool
        When False the function returns immediately without touching ``ad``.
    """
    if not print_msg:
        return

    #positional lookup: plain [] with an integer is label-based on integer indexes
    #and deprecated on other indexes, so use .iloc to always get the i-th gene
    gene_name = ad.var['genes'].iloc[i]
    gene_mean = ad.var['gene_mean_withinCluster'].iloc[i]
    print(f'\nfitting gene {i}/{ad.shape[1]}: {gene_name},'\
          f'mean={gene_mean:.2f} raised {outcome_type}: {text}')

def fit_zinb_within_cluster(ad_input,mean_cutoff = 5,save=True,verbose=False,filename='zinbfit.h5ad',show_warnings=False):
    """Fit a ZINB model to every sufficiently expressed gene and log all outcomes.

    Genes whose ``var['gene_mean_withinCluster']`` exceeds ``mean_cutoff`` are
    kept. For each of them ``zinb_fit`` is run; fitted parameters, convergence
    information, warnings recorded during the fit and ``np.linalg.LinAlgError``
    messages are stored per gene in ``.var``.

    Parameters
    ----------
    ad_input : AnnData
        Input counts. It is not modified: the function works on a copy of the
        filtered subset.
    mean_cutoff : float
        Genes with a within-cluster mean above this value are fitted.
    save : bool
        When True the annotated object is written to ``filename``.
    verbose : bool
        Forwarded to ``zinb_fit``.
    filename : str
        Target path of the ``.h5ad`` file (only used when ``save`` is True).
    show_warnings : bool
        When True, non-trivial warnings and errors are printed per gene.

    Returns
    -------
    AnnData
        The filtered copy carrying all result columns, so the fit is usable
        even when ``save`` is False.
    """
    warning2text = {'log_warning': 'invalid value encountered in log',
                'multiply_warning': 'invalid value encountered in multiply',
                'subtract_warning': 'invalid value encountered in subtract',
                'overflow_warning_exp': 'overflow encountered in exp',
                'overflow_warning_reduce': 'overflow encountered in reduce'}

    #filter out low expression genes; take an explicit copy because columns are
    #added to .var below and writing to a view triggers an implicit copy + warning
    ad = ad_input[:,ad_input.var['gene_mean_withinCluster'] > mean_cutoff].copy()
    print(f'{ad.shape[1]} genes > {mean_cutoff}')

    #add fields to adata object to save results
    init_ad_for_fitting(ad,warning2text)

    for i,gene_idx in enumerate(ad.var.index):
        #print progress
        if i % 100 == 0:
            print('.',end='')

        #capture warnings during fit but ignore pandas future warning
        with warnings.catch_warnings(record=True) as ws:
            warnings.filterwarnings("ignore", category=FutureWarning)
            #densify the single-gene column; works for sparse and dense X
            gene_X = ad[:,gene_idx].X
            if sparse.issparse(gene_X):
                gene_X = gene_X.toarray()
            gene_counts = np.asarray(gene_X).squeeze()

            #capture np.linalg.LinAlgErrors during fits
            try:
                mu_fit, theta_fit, inflation_fit, fit_object = zinb_fit(gene_counts,verbose=verbose)
                #save fit results
                ad.var.loc[gene_idx,'mus_fit']=mu_fit
                ad.var.loc[gene_idx,'thetas_fit']=theta_fit
                ad.var.loc[gene_idx,'inflations_fit']=inflation_fit
                #save fit stats/infos (fit object can't be saved in hdf5)
                ad.var.loc[gene_idx,'converged']=fit_object.mle_retvals['converged']
                ad.var.loc[gene_idx,'n_iters']=fit_object.mle_retvals['iterations']
                ad.var.loc[gene_idx,'warnflags_statsmodels']=fit_object.mle_retvals['warnflag']
            except np.linalg.LinAlgError as error:
                #the gene keeps its NaN parameters; only the error text is recorded
                ad.var.loc[gene_idx,'cought_errors'] = error.args[0]
                print_warnings(ad,i,outcome_type='np.linalg.LinAlgError',text=error.args[0],print_msg=show_warnings)

        #processing warnings of current gene
        cought_warnings_list = [w.message.args[0] for w in ws]
        if cought_warnings_list:
            unique_warnings = np.unique(cought_warnings_list)
            #print warning for all but the log warnings
            if not (len(unique_warnings)==1 and unique_warnings[0]=='invalid value encountered in log'):
                print_warnings(ad,i,outcome_type='warning(s)',text=cought_warnings_list,print_msg=show_warnings)
            #save warning type in flag
            n_known_warning_types = 0
            for warning_type,warning_text in warning2text.items():
                is_present = bool(warning_text in unique_warnings)
                ad.var.loc[gene_idx,f'{warning_type}_idx'] = is_present
                n_known_warning_types += int(is_present) #add 1 if warning is known
            #any unique message that is not one of the known texts counts as "other"
            ad.var.loc[gene_idx,'other_warning_idx'] = n_known_warning_types < len(unique_warnings)
        else:
            ad.var.loc[gene_idx,'no_warning_idx'] = True

        #save fit warnings
        ad.var.loc[gene_idx,'cought_warnings']=str(cought_warnings_list)

    #save invalid/unexpected outcomes
    ad.uns['warning_types'] = np.array(list(warning2text.keys()) + ['other_warning'])
    ad.uns['warning2text'] = warning2text
    ad.var['no_error'] = ad.var.cought_errors == ''
    ad.var['valid_inflation_idx'] = ad.var['inflations_fit'] <= 1

    #honour the save flag (it used to be ignored and the file was always written)
    if save:
        ad.write_h5ad(filename=filename)
        print(f'saved to {filename}')

    return ad

### Fit ZINB to real data

In [4]:
#mean-expression cutoff; used for the gene filter and encoded in the output filename
mean_cutoff=5
adata_single_cluster = anndata.read_h5ad(f'{basepath}adata_single_cluster.h5ad')
clustername=adata_single_cluster.uns['clustername']
print(clustername,adata_single_cluster.shape)

filename = f'{basepath_simulations}ZINBfit_single_cluster_meanCutoff_{mean_cutoff}.h5ad'
#pass the cutoff explicitly so the applied filter always matches the value in the filename
#(the trailing ';' keeps the cell from echoing a return value)
fit_zinb_within_cluster(adata_single_cluster,mean_cutoff=mean_cutoff,filename=filename);

L6 IT VISp Penk Col27a1 (1049, 33914)
11549 genes > 5
.

  ad.var['mus_fit']=np.nan * np.ones(ad.shape[1]).astype(bool)


...................................................................................................................saved to data/tasic/simulations/ZINBfit_single_cluster_meanCutoff_5.h5ad


### Grid search in broken zeta parameter space

In [5]:
#todo sync
def broken_zeta_w_stats_new(a1=1.4,
                a2=8.0,
                breakpoint=100,
                size=10000,
                z_max=100000,
                seed=42):
    """Draw samples with ``broken_zeta`` and summarise them.

    All arguments are forwarded unchanged, as keywords, to ``broken_zeta``.

    Returns
    -------
    params : dict
        The keyword arguments that were passed to ``broken_zeta``.
    stats : dict
        ``mean``, ``median`` and ``var`` of the samples, the Fano factor
        ``ff = var/mean`` and ``alpha = mean + ff``.
    """
    params = {'a1': a1, 'a2': a2, 'breakpoint': breakpoint,
              'size': size, 'z_max': z_max, 'seed': seed}
    samples = broken_zeta(**params)

    sample_mean = np.mean(samples)
    sample_var = np.var(samples)
    fano_factor = sample_var/sample_mean

    #named `summary` locally to avoid shadowing the imported scipy `stats` module
    summary = {'mean': sample_mean,
               'median': np.median(samples),
               'var': sample_var,
               'ff': fano_factor,
               'alpha': sample_mean+fano_factor}
    return params, summary

In [6]:
def grid_search_broken_zeta():
    """Evaluate broken-zeta sample statistics on a parameter grid and pickle the table.

    The grid covers ``a1`` in ``arange(0.01,5,0.05)``, ``a2`` in
    ``arange(0.1,20,2.5)`` (restricted to ``a2 > a1``) and breakpoints in
    ``arange(1,100,5)``. One row per combination, holding the statistics and
    the parameters, is collected in a DataFrame that is pickled to
    ``f'{basepath_simulations}gridsearch_broken_zeta'``.
    """
    a2_grid = np.arange(0.1,20,2.5)
    breakpoint_grid = np.arange(1,100,5)

    rows = []
    for a1 in np.arange(0.01,5,0.05):
        #require a2>a1 for observed shape
        admissible_a2 = [a2 for a2 in a2_grid if a2 > a1]
        for a2 in admissible_a2:
            for b in breakpoint_grid:
                param,stat = broken_zeta_w_stats_new(a1=a1,a2=a2,breakpoint=b)
                rows.append(dict(**stat,**param))

    df_results = pd.DataFrame(rows)

    with open(f'{basepath_simulations}gridsearch_broken_zeta','wb') as f:
        pickle.dump(df_results,f)
#run the grid search; it writes the result table to disk and returns nothing
grid_search_broken_zeta()

# load zeta param search result
# (this is the file written by grid_search_broken_zeta(); only unpickle files you trust)
with open(f'{basepath_simulations}gridsearch_broken_zeta','rb') as f:
    df_grid_search_results = pickle.load(f)

In [7]:
def zeta_stat_str(zeta_params,amplification_stats):
    """Format the amplification statistics as a short label.

    Parameters
    ----------
    zeta_params : any
        Not used by this function.
    amplification_stats : mapping
        Must provide ``"mean"`` and ``"ff"``.

    Returns
    -------
    str
        ``'alpha=<mean+ff>, E[Z]=<mean>'`` with both numbers rounded to integers.
    """
    mean_z = amplification_stats["mean"]
    alpha = mean_z+amplification_stats["ff"]
    return f'alpha={alpha:.0f}, E[Z]={mean_z:.0f}'

### Select parameter sets

In [8]:
def find_best_params(df,target_alpha=50,target_mean=20,max_total_deviation=5):
    """Pick the grid-search row whose ``alpha`` and ``mean`` are closest to the targets.

    The total deviation of a row is ``|alpha-target_alpha| + |mean-target_mean|``.
    Rows with a total deviation of ``max_total_deviation`` or more are discarded
    and the row with the smallest deviation among the rest is selected.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``alpha``, ``mean``, ``ff``, ``a1``, ``a2``,
        ``breakpoint`` and ``z_max``. It is not modified.
    target_alpha, target_mean : float
        Desired values of the ``alpha`` and ``mean`` columns.
    max_total_deviation : float
        Exclusive upper bound on the accepted total deviation.

    Returns
    -------
    best_params : dict
        ``a1``, ``a2``, ``breakpoint`` and ``z_max`` of the selected row plus
        ``constant=False``.
    best_obs_mean, best_obs_ff : float
        ``mean`` and ``ff`` of the selected row.

    Raises
    ------
    IndexError
        If no row lies within ``max_total_deviation`` of the targets.
    """
    total_deviation = np.abs(df['alpha']-target_alpha) + np.abs(df['mean']-target_mean)
    candidates = df.assign(total_deviation=total_deviation)
    candidates = candidates[candidates['total_deviation']<max_total_deviation]

    #an empty selection used to surface as an opaque positional-indexer error
    if candidates.empty:
        raise IndexError(f'no parameter set within max_total_deviation={max_total_deviation} '
                         f'of target_alpha={target_alpha}, target_mean={target_mean}')

    #look the best row up once instead of re-indexing for every field
    best_row = candidates.sort_values(by='total_deviation').iloc[0,:]

    best_params = dict(a1=best_row['a1'],
                   a2=best_row['a2'],
                   breakpoint=best_row['breakpoint'],
                   z_max=best_row['z_max'],
                   constant=False)

    return best_params, best_row['mean'], best_row['ff']

In [9]:
#one entry per amplification model: the sampling parameters and a matching info record
zeta_params_list = []
zeta_params_info_list = []
target_alphas = [1,50,50,50]
target_means = [1,20,30,40]
set_z_constant = [True, False, False, False]

for z_constant,target_mean,target_alpha,color in zip(set_z_constant,target_means,target_alphas,['k','tab:blue','tab:orange','tab:green']):
    if z_constant:
        #constant model Z=1: no broken-zeta parameters, statistics fixed by definition
        best_params = {'a1': np.nan, 'a2': np.nan, 'breakpoint': np.nan, 'z_max': np.nan, 'constant': True}
        info = {'target_mean': 1, 'target_alpha': 1, 'target_ff': 0,
                'obs_mean': 1, 'obs_ff': 0, 'obs_alpha': 1,
                'color': color}
        print('For E[Z]=1 and scale=1 we use the constant model Z=1')
    else:
        #closest grid-search entry to the requested mean and alpha
        best_params, best_obs_mean, best_obs_ff = find_best_params(df_grid_search_results,
                                                                   target_alpha=target_alpha,
                                                                   target_mean=target_mean)

        info = {'target_mean': target_mean, 'target_alpha': target_alpha, 'target_ff': target_alpha-target_mean,
                'obs_mean': best_obs_mean, 'obs_ff': best_obs_ff, 'obs_alpha': best_obs_mean+best_obs_ff,
                'color': color}

        print(f'For E[Z]={target_mean} and alpha={target_alpha} we choose: {best_params}\n(leading to E[Z]={best_obs_mean:.1f} and FF[Z]={best_obs_ff:.1f})\n')
    zeta_params_list.append(best_params)
    zeta_params_info_list.append(info)

For E[Z]=1 and scale=1 we use the constant model Z=1
For E[Z]=20 and alpha=50 we choose: {'a1': 0.9600000000000001, 'a2': 15.1, 'breakpoint': 91.0, 'z_max': 100000.0, 'constant': False}
(leading to E[Z]=19.6 and FF[Z]=30.2)

For E[Z]=30 and alpha=50 we choose: {'a1': 0.36000000000000004, 'a2': 5.1, 'breakpoint': 56.0, 'z_max': 100000.0, 'constant': False}
(leading to E[Z]=29.6 and FF[Z]=20.1)

For E[Z]=40 and alpha=50 we choose: {'a1': 0.01, 'a2': 17.6, 'breakpoint': 71.0, 'z_max': 100000.0, 'constant': False}
(leading to E[Z]=37.4 and FF[Z]=12.7)



### Simulate readcounts with Broken Zeta for the selected parameter sets and fit ZINB

In [10]:
def fit_multiple_zeta_simulations(zeta_param_sets,infos_zeta_param_sets,molsim_params,
                                  mean_cutoff = 5,tag='tag',show_warnings=False):
    """Simulate read counts for each parameter set and fit ZINB models to the result.

    Parameters
    ----------
    zeta_param_sets : sequence
        Amplification parameter sets, handed one at a time to ``simulate_readcounts``.
    infos_zeta_param_sets : sequence of dict
        One info record per parameter set; its ``'color'`` entry is passed to
        the simulation and the whole record is stored in ``.uns``.
    molsim_params : dict
        Passed unchanged to ``simulate_readcounts``.
    mean_cutoff : float
        Forwarded to ``fit_zinb_within_cluster`` and encoded in the filename.
    tag : str
        Forwarded to ``simulate_readcounts``.
    show_warnings : bool
        Forwarded to ``fit_zinb_within_cluster``.
    """
    paired_sets = zip(zeta_param_sets,infos_zeta_param_sets)
    for set_idx,(zeta_params,info_param_set) in enumerate(paired_sets):

        #simulate (the second return value is not needed here)
        ad_sim, _ = simulate_readcounts(molsim_params,zeta_params,amplification_seed=42,
                                        tag=tag,color=info_param_set['color'])
        ad_sim.uns['infos_zeta_param_set'] = info_param_set

        #fit ZINB and save to file, one file per parameter set
        out_path = f'{basepath_simulations}ZINBfit_simulation_{set_idx}_meanCutoff_{mean_cutoff}.h5ad'
        fit_zinb_within_cluster(ad_sim,mean_cutoff=mean_cutoff, filename=out_path,show_warnings=show_warnings)

In [11]:
#settings handed unchanged to the read-count simulation
molsim_params = {'n_cells': 1000*100,
                 'ps_input': np.logspace(-8,-1,num=25),
                 'theta_molecules': 10,
                 'depth': 100000,
                 'seed': 42}

#simulate and fit once per selected parameter set; results are written to disk
fit_multiple_zeta_simulations(
    zeta_param_sets=zeta_params_list,
    infos_zeta_param_sets=zeta_params_info_list,
    molsim_params=molsim_params,
    tag='fig6',
    mean_cutoff=mean_cutoff,
)

removing 0 all-zero genes after simulation
12 genes > 5
.

  ad = anndata.AnnData(X=sparse.csc_matrix(readcounts_sim),layers=dict(molecules=molecules_sim))
  ad.var['mus_fit']=np.nan * np.ones(ad.shape[1]).astype(bool)


saved to data/tasic/simulations/ZINBfit_simulation_0_meanCutoff_5.h5ad
removing 0 all-zero genes after simulation


tcmalloc: large alloc 16377569280 bytes == 0x1d98c000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x3ef1c2000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x7c04a0000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x3ef1c2000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x7c04a0000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x7c04a0000 @ 


Broken Zeta amplification with {'a1': 0.9600000000000001, 'a2': 15.1, 'breakpoint': 91.0, 'z_max': 100000.0, 'constant': False}
Effectively amplifying with Zs that have mean=20.1, median=8.0, var=609.9, FF=30.3, leading to alpha=50.4


  ad = anndata.AnnData(X=sparse.csc_matrix(readcounts_sim),layers=dict(molecules=molecules_sim))
  ad.var['mus_fit']=np.nan * np.ones(ad.shape[1]).astype(bool)


16 genes > 5
.saved to data/tasic/simulations/ZINBfit_simulation_1_meanCutoff_5.h5ad
removing 0 all-zero genes after simulation


tcmalloc: large alloc 16377569280 bytes == 0x3ef1c2000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x1a422000 @ 
tcmalloc: large alloc 16377569280 bytes == 0xb9177e000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x1a422000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x3ef1c2000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x3ef1c2000 @ 


Broken Zeta amplification with {'a1': 0.36000000000000004, 'a2': 5.1, 'breakpoint': 56.0, 'z_max': 100000.0, 'constant': False}
Effectively amplifying with Zs that have mean=30.2, median=25.0, var=639.1, FF=21.2, leading to alpha=51.4


  ad = anndata.AnnData(X=sparse.csc_matrix(readcounts_sim),layers=dict(molecules=molecules_sim))
  ad.var['mus_fit']=np.nan * np.ones(ad.shape[1]).astype(bool)


17 genes > 5
.saved to data/tasic/simulations/ZINBfit_simulation_2_meanCutoff_5.h5ad
removing 0 all-zero genes after simulation


tcmalloc: large alloc 16377569280 bytes == 0x3ef1c2000 @ 
tcmalloc: large alloc 16377569280 bytes == 0xb9177e000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x1a422000 @ 
tcmalloc: large alloc 16377569280 bytes == 0xb9177e000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x3ef1c2000 @ 
tcmalloc: large alloc 16377569280 bytes == 0x3ef1c2000 @ 


Broken Zeta amplification with {'a1': 0.01, 'a2': 17.6, 'breakpoint': 71.0, 'z_max': 100000.0, 'constant': False}
Effectively amplifying with Zs that have mean=37.9, median=38.0, var=478.7, FF=12.6, leading to alpha=50.5


  ad = anndata.AnnData(X=sparse.csc_matrix(readcounts_sim),layers=dict(molecules=molecules_sim))
  ad.var['mus_fit']=np.nan * np.ones(ad.shape[1]).astype(bool)


17 genes > 5
.saved to data/tasic/simulations/ZINBfit_simulation_3_meanCutoff_5.h5ad


In [12]:
#persist the selected parameter sets
with open(f'{basepath_simulations}zeta_params_list.pickle','wb') as f:
    pickle.dump(zeta_params_list,f)
#persist the matching info records (targets, observed statistics, colors)
with open(f'{basepath_simulations}zeta_params_info_list.pickle','wb') as f:
    pickle.dump(zeta_params_info_list,f)