# Setup

In [None]:
import os
import pickle

import numpy as np
import pandas as pd
import scipy

import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import plotly.express as px

from tqdm.notebook import tqdm, trange

# Publication-oriented matplotlib defaults: TrueType (Type 42) fonts in PDF/PS output,
# SVG text left as text (not paths), Arial as the sans-serif font, and pure-black text/ticks
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'font.family': 'sans-serif',
    'text.color': '#000000',
    'axes.labelcolor': '#000000',
    'xtick.color': '#000000',
    'ytick.color': '#000000',
})

In [None]:
import random
from random import sample

# Seed every RNG this notebook draws from, so the simulations are reproducible:
# - Python's `random` drives the strain orders in `_simulate_rare_categories` (via `sample`)
# - NumPy's global RNG (np.random) drives the strain shuffles in `estimate_pan_core_size`;
#   `random.seed` does not affect it, so it must be seeded separately
# NOTE: worker processes started by multiprocessing.Pool reseed `random` themselves, so the
# rare-category simulations are not reproducible from this seed alone.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# Figure resolution and seaborn theme for every plot below (pdf.fonttype is set in the setup cell)
plt.rcParams["figure.dpi"] = 300
sns.set_palette("deep")
sns.set_context("paper")
sns.set_style("whitegrid")
# sns.set_style also sets text/label/tick colours (to seaborn's dark grey), so restore the
# pure-black colours chosen in the setup cell
plt.rcParams.update({
    'text.color': '#000000',
    'axes.labelcolor': '#000000',
    'xtick.color': '#000000',
    'ytick.color': '#000000',
})

In [None]:
# Gene x strain presence/absence matrix, stored as a sparse frame on disk (the path suggests
# CD-HIT clusters at 80% similarity). Missing entries become 0 and the frame is densified to int8.
# NOTE: read_pickle can execute arbitrary code -- only load trusted files.
df_genes = (
    pd.read_pickle('../../data/processed/cd-hit-results/sim80/Ebacter_strain_by_gene.pickle.gz')
    .fillna(0)
    .sparse.to_dense()
    .astype('int8')
)

display(df_genes.shape, df_genes.head())

In [None]:
# Per-genome metadata with every column read as a string; the next cell filters on its
# genome_status column and uses genome_id to select matrix columns
metadata = pd.read_csv(
    '../../data/metadata/mash_scrubbed_species_metadata.csv',
    dtype='object',
    index_col=0,
)

display(metadata.shape, metadata.head())

In [None]:
# Restrict the metadata to genomes whose genome_status is 'Complete'
metadata_complete = metadata.loc[metadata['genome_status'] == 'Complete']

# Keep only those genomes' columns of the gene x genome matrix, then drop the gene
# clusters that are absent from every complete genome
df_genes_complete = df_genes.loc[:, metadata_complete['genome_id']].copy()
inCompleteseqs = df_genes_complete.sum(axis=1) > 0
df_genes_complete = df_genes_complete.loc[inCompleteseqs]

df_genes_complete.shape

In [None]:
# eggNOG functional annotations, every column read as a string; the rare-category simulation
# below looks genes up by this frame's index and reads 'Description' and 'COG_category'
df_eggnog = pd.read_csv(
    '../../data/processed/df_eggnog.csv',
    dtype='object',
    index_col=0,
)

display(df_eggnog.shape, df_eggnog.head())

## Useful functions

In [None]:
import scipy.sparse
from scipy.sparse import coo_matrix

def estimate_pan_core_size(df_genes, num_iter=10, log_batch=1,
                           core_threshold=0.987, rare_threshold=0.037):
    '''
    Computes pan/core/accessory/rare genome size curves for many randomizations.

    Each iteration visits the strains in a random order drawn from NumPy's global RNG.
    After n strains, a gene seen in k of them is classified as
        core       if k >= core_threshold * n
        rare       if 0 < k < rare_threshold * n
        accessory  if it is present but neither core nor rare
    so Pan_n == Core_n + Acc_n + Rare_n whenever rare_threshold <= core_threshold.

    Parameters
    ----------
    df_genes : pd.DataFrame
        Gene x strain presence/absence matrix. A frame whose columns all have pandas sparse
        dtypes is converted with ``.sparse.to_coo()``; any other frame is converted from its
        dense values.
    num_iter : int
        Number of random strain orders to simulate.
    log_batch : int
        Print a progress line every ``log_batch`` iterations (0 or None disables printing).
    core_threshold : float
        Fraction of the strains sampled so far that a gene must reach to count as core.
    rare_threshold : float
        Genes found in fewer than this fraction of the strains sampled so far are rare.

    Returns
    -------
    pd.DataFrame
        Indexed 'Iter1'..'Iter<num_iter>', with columns Pan1..PanN, Core1..CoreN,
        Acc1..AccN, Rare1..RareN (N = number of strains). Column k holds the section
        size after k strains.
    '''
    num_genes, num_strains = df_genes.shape

    if num_strains > 0 and all(isinstance(dtype, pd.SparseDtype) for dtype in df_genes.dtypes):
        gene_data = df_genes.sparse.to_coo()
    else:
        gene_data = scipy.sparse.coo_matrix(df_genes.to_numpy())
    gene_data = gene_data.T.tocsr()  # now strain x gene: one row slice per strain

    pan_genomes = np.zeros((num_iter, num_strains))   # pan-genome curve per iteration
    core_genomes = np.zeros((num_iter, num_strains))  # core-genome curve per iteration
    acc_genomes = np.zeros((num_iter, num_strains))   # accessory-genome curve per iteration
    rare_genomes = np.zeros((num_iter, num_strains))  # rare-genome curve per iteration

    # Simulate pan/core-genomes for randomly ordered strains
    for i in trange(num_iter):
        if log_batch and (i + 1) % log_batch == 0:
            print('Iteration', i+1, 'of', num_iter)
        shuffle_indices = np.random.permutation(num_strains)
        # number of strains seen so far that contain each gene
        gene_incidence = np.zeros(num_genes, dtype=int)
        for j, strain_idx in enumerate(shuffle_indices):
            # densify the 1 x num_genes sparse row so the accumulator stays a plain 1-D ndarray
            gene_incidence += gene_data[strain_idx].toarray().ravel()
            # every cutoff scales with the same number of strains sampled so far
            n_seen = j + 1
            is_present = gene_incidence > 0
            is_core = is_present & (gene_incidence >= core_threshold * n_seen)
            is_rare = is_present & (gene_incidence < rare_threshold * n_seen)
            pan_genomes[i, j] = is_present.sum()
            core_genomes[i, j] = is_core.sum()
            rare_genomes[i, j] = is_rare.sum()
            # accessory = everything present that is neither core nor rare, so no gene
            # falls between sections at the cutoff boundaries
            acc_genomes[i, j] = (is_present & ~is_core & ~is_rare).sum()

    # Save to DataFrame
    iter_index = [f'Iter{k}' for k in range(1, num_iter + 1)]
    columns = [
        f'{prefix}{k}'
        for prefix in ('Pan', 'Core', 'Acc', 'Rare')
        for k in range(1, num_strains + 1)
    ]
    return pd.DataFrame(index=iter_index, columns=columns,
                        data=np.hstack([pan_genomes, core_genomes, acc_genomes, rare_genomes]))


def fit_heaps(df_freqs):
    ''' Fits a single iteration to Heaps' law: size = kappa * N**lambda_, with N = 1..len(df_freqs).

    Parameters
    ----------
    df_freqs : pd.Series or array-like
        Genome-section size after 1, 2, ..., n strains (one simulated iteration).

    Returns
    -------
    np.ndarray
        The fitted parameters ``[lambda_, kappa]``.
    '''
    # Import the submodule explicitly: a bare `import scipy` does not load scipy.optimize
    # on every SciPy version
    from scipy.optimize import curve_fit

    def heaps(x, lambda_, kappa):
        return kappa * np.power(x, lambda_)

    sizes = np.asarray(df_freqs, dtype=float)
    n_strains = sizes.shape[0]
    # Initial guess: sub-linear growth, scaled by the smallest observed size
    p0 = [0.5, float(sizes.min())]
    popt, _ = curve_fit(heaps, np.arange(1, n_strains + 1), sizes, p0=p0)
    return popt


def fit_heaps_by_iteration(df_pan_core, section='pan'):
    ''' Fits Heaps' law to each iteration's curve for one genome section.

    Parameters
    ----------
    df_pan_core : pd.DataFrame
        Output of ``estimate_pan_core_size``: one row per iteration, with columns named
        Pan<k>, Core<k>, Acc<k> and Rare<k>.
    section : str
        'pan', 'core', 'acc' or 'rare' (case-insensitive).

    Returns
    -------
    pd.DataFrame
        Indexed by iteration label, with columns 'lambda_' and 'kappa'.

    Raises
    ------
    ValueError
        If ``section`` is not one of the four names above.
    '''
    prefixes = {'pan': 'Pan', 'core': 'Core', 'acc': 'Acc', 'rare': 'Rare'}
    prefix = prefixes.get(section.lower())
    if prefix is None:
        raise ValueError(f"section must be one of {list(prefixes)}, got {section!r}")

    # Transpose so that each column is one iteration's curve (row k-1 = size after k strains)
    curves = df_pan_core[[col for col in df_pan_core.columns if col.startswith(prefix)]].T

    heaps_fits = {}
    for i, iter_label in enumerate(curves.columns):
        lambda_, kappa = fit_heaps(curves.iloc[:, i])
        heaps_fits[iter_label] = {'lambda_': lambda_, 'kappa': kappa}
    return pd.DataFrame.from_dict(heaps_fits, orient='index').reindex(curves.columns)

# Heaps' Law Plot for CAR genomes

## Total (Complete + WGS)

In [None]:
# Sparse int8 copy of the gene matrix -- the input format estimate_pan_core_size is fed
df_genes_sparse = df_genes.astype(pd.SparseDtype("int8", 0))

# 20 random strain orderings -> pan/core/accessory/rare genome size curves
df_pan_core = estimate_pan_core_size(
    df_genes_sparse,
    num_iter=20,
    log_batch=1,
)

In [None]:
df_pan_core.head(5)  # first five simulated iterations

In [None]:
# Per-iteration Heaps'-law fits (lambda_, kappa) for each genome section
output_pan, output_core, output_acc, output_rare = (
    fit_heaps_by_iteration(df_pan_core, section=name)
    for name in ('pan', 'core', 'acc', 'rare')
)

In [None]:
# Heaps' law exponent (lambda) of the pan-genome, averaged over iterations.
# In size = kappa * N**lambda, a larger lambda means the pan-genome keeps growing faster as
# strains are added (a more "open" pan-genome); lambda near 0 means it has nearly plateaued.
output_pan['lambda_'].mean()

In [None]:
# Mean Heaps'-law curve kappa * N**lambda (parameters averaged over iterations) for the
# core, accessory and rare genome at N = 1..number of strains
x = list(range(1, df_genes.shape[1] + 1))

y_core, y_acc, y_rare = (
    fits.kappa.mean() * np.array(x) ** fits.lambda_.mean()
    for fits in (output_core, output_acc, output_rare)
)

In [None]:
# Log-linear plot: stacked core / accessory / rare curves on a log-scaled y-axis
fig, ax = plt.subplots()
ax.stackplot(x, y_core, y_acc, y_rare)
ax.set(yscale='log')
ax.grid(False)
plt.show()

In [None]:
# Export the total (Complete + WGS) stacked Heaps' plot from the previous cell as a
# transparent-background PDF
fig.savefig('../../data/figures/heaps_law_total.pdf', transparent=True)

## Complete sequences only

In [None]:
# Sparse int8 copy of the complete-genome matrix -- the input format estimate_pan_core_size is fed
df_genes_complete_sparse = df_genes_complete.astype(pd.SparseDtype("int8", 0))

# 20 random strain orderings -> pan/core/accessory/rare genome size curves
df_pan_core_complete = estimate_pan_core_size(
    df_genes_complete_sparse,
    num_iter=20,
    log_batch=1,
)

In [None]:
df_pan_core_complete.head(5)  # first five simulated iterations

In [None]:
# Per-iteration Heaps'-law fits (lambda_, kappa) for each genome section, complete genomes only
output_pan_complete, output_core_complete, output_acc_complete, output_rare_complete = (
    fit_heaps_by_iteration(df_pan_core_complete, section=name)
    for name in ('pan', 'core', 'acc', 'rare')
)

In [None]:
# Heaps' law exponent (lambda) of the complete-genome pan-genome, averaged over iterations.
# In size = kappa * N**lambda, a larger lambda means the pan-genome keeps growing faster as
# strains are added (a more "open" pan-genome); lambda near 0 means it has nearly plateaued.
output_pan_complete['lambda_'].mean()

In [None]:
# Mean Heaps'-law curve kappa * N**lambda (parameters averaged over iterations) for the
# core, accessory and rare genome of the complete genomes at N = 1..number of strains
x_complete = list(range(1, df_genes_complete.shape[1] + 1))

y_core_complete, y_acc_complete, y_rare_complete = (
    fits.kappa.mean() * np.array(x_complete) ** fits.lambda_.mean()
    for fits in (output_core_complete, output_acc_complete, output_rare_complete)
)

In [None]:
# Log-linear plot: stacked core / accessory / rare curves (complete genomes) on a log y-axis
fig, ax = plt.subplots()
ax.stackplot(x_complete, y_core_complete, y_acc_complete, y_rare_complete)
ax.set(yscale='log')
ax.grid(False)
plt.show()

In [None]:
# Export the complete-genome stacked Heaps' plot from the previous cell as a
# transparent-background PDF
fig.savefig('../../data/figures/heaps_law_complete.pdf', transparent=True)

In [None]:
# Log-linear plot of the complete-genome curves with axis labels and the fitted Heaps'
# exponent, exported as an SVG for the supplement
fig, ax = plt.subplots()

ax.stackplot(x_complete, y_core_complete, y_acc_complete, y_rare_complete)
ax.set_yscale('log')
ax.set_ylim(bottom=1000)
ax.grid(False)
# Annotate with the exponent computed above rather than a hand-copied constant, so the
# label always matches the current fit
heaps_lambda_complete = output_pan_complete.lambda_.mean()
ax.text(10, 60000, f"Heaps' Coefficient: {heaps_lambda_complete:.3f}", style='italic')
ax.set_xlabel('Number of strains')
ax.set_ylabel('Number of Genes')
fig.savefig('../images/supplemental/heaps_law.svg', format='svg')
plt.show()

# Heaps' Law Plot: Rare Genome Stratified by COG and Functional Category

In [None]:
# Rare-genome cutoffs as fractions of strains, taken from notebook 3a
# (presumably 198 of 2709 strains overall and 35 of 517 complete strains -- verify against 3a)
rg_threshold_total, rg_threshold_complete = 198/2709, 35/517

display(rg_threshold_total, rg_threshold_complete)

In [None]:
import string
from multiprocessing import Pool

In [None]:
# Change rg_threshold to whatever your rare-genome threshold is (e.g. =0.05 for 5%)
# For L. plantarum, we used 15% (=0.15) as the cutoff for the rare/accessory genomes
# For E. coli, this number is 3.7% (using the cutoff method based on gene frequency)

# Generates multiple df_pan_core-like matrices for multiple iterations (can take a lot of time to run!)
def estimate_rare_categories_sizes(
    df_genes, df_eggnog, rg_threshold=0.07,
    cog_list=None, functional_list=None, num_iter=10, processes=-1
):
    '''Simulate rare-genome accumulation curves per COG / functional category over many iterations.

    Each iteration is one call to ``_simulate_rare_categories`` in a worker process, using its
    own random strain order. The per-iteration curves are collected into one DataFrame per category.

    Parameters
    ----------
    df_genes : pd.DataFrame
        Gene x strain presence/absence matrix.
    df_eggnog : pd.DataFrame
        eggNOG annotations indexed by gene, with 'Description' and 'COG_category' columns.
    rg_threshold : float
        Rare-genome cutoff as a fraction of all strains.
    cog_list : list of str, optional
        COG letters to track; defaults to A-Z plus '-' (no COG category).
    functional_list : list of str, optional
        Description keywords to track; defaults to ['Prophage', 'Transposase'].
    num_iter : int
        Number of random strain orders to simulate.
    processes : int or None
        Worker processes for ``multiprocessing.Pool``. -1 (the default), None, or any value < 1
        means one worker per CPU.

    Returns
    -------
    dict[str, pd.DataFrame]
        One frame per category, indexed 'Iter1'..'Iter<num_iter>', with columns 'Rare1'..'RareN'.

    Notes
    -----
    NOTE(review): Pool sends the worker function by reference. In a notebook this only works when
    workers are forked (the Linux default); with the 'spawn' start method (macOS/Windows default)
    the notebook-defined helper cannot be found by the workers.
    ``random.seed`` in the parent does not make the worker strain orders reproducible, because
    forked children reseed ``random``.
    '''
    # List of COG categories
    if cog_list is None:
        cog_list = list(string.ascii_uppercase)
        cog_list.append('-') # no COG category, no known orthologs

    # List of functional categories
    if functional_list is None:
        functional_list = ['Prophage', 'Transposase']

    # multiprocessing.Pool raises ValueError for processes < 1, so translate the -1 default
    # (and None / other non-positive values) to None, which Pool treats as os.cpu_count()
    if processes is None or processes < 1:
        processes = None

    # Fit to Heaps' law by simulating num_iter independent iterations (one random strain order each)
    args = [(df_genes, df_eggnog, rg_threshold, cog_list, functional_list)] * num_iter
    with Pool(processes=processes) as pool:
        results = pool.starmap(_simulate_rare_categories, args)

    # One DataFrame per category: rows = iterations, columns = number of strains sampled
    iter_index = [f'Iter{i+1}' for i in range(num_iter)]
    rare_cols = [f'Rare{i+1}' for i in range(df_genes.shape[1])]
    return {
        category: pd.DataFrame(
            data=[result[category] for result in results],
            index=iter_index,
            columns=rare_cols,
        )
        for category in cog_list + functional_list
    }


def _simulate_rare_categories(df_genes, df_eggnog, rg_threshold=0.07, cog_list=None, functional_list=None):
    '''Simulates a single iteration for rare genes by given COG or functional category.'''
    
    # Generate random strain order for iteration
    strain_order = sample(df_genes.columns.tolist(), k=len(df_genes.columns))
    curr_order = [] 
    
    # List of COG categories
    if cog_list is None:
        cog_list = list(string.ascii_uppercase)
        cog_list.append('-') # no COG category, no known orthologs
    
    # List of funcational categories
    if functional_list is None:
        functional_list = ['Prophage', 'Transposase']
    
    # Return dict
    categories_sim = dict.fromkeys(cog_list+functional_list)
    for key in categories_sim.keys():
        categories_sim[key] = np.zeros(df_genes.shape[1], dtype='int')
    
    # Get list of rare genes
    inRareGenome = df_genes.sum(axis=1) < rg_threshold*df_genes.shape[1]
    rare_genes = df_genes[inRareGenome].index.tolist()
    
    # Simulate Heaps' Plot size curve for each COG and functional category
    for i, strain in tqdm(enumerate(strain_order), total=len(strain_order)):
        curr_order = strain_order[:i+1]
        df_genes_curr_order = df_genes.loc[rare_genes, curr_order]
        
        inCurrentRareGenome = df_genes_curr_order.sum(axis=1) > 0
        curr_rare_genes = df_genes_curr_order[inCurrentRareGenome].index.tolist()
        
        # Get functional category sums (Part 1: for Prophage)
        if 'Prophage' in functional_list:
            value = (
                df_eggnog.loc[curr_rare_genes].Description.str.lower().str.contains('phage ').sum()
                + df_eggnog.loc[curr_rare_genes].Description.str.lower().str.contains('prophage').sum()
            )
            categories_sim['Prophage'][i] = value
        
        # Get functional category sums (Part 2: for everything else)
        curr_list = [x for x in functional_list if x != 'Prophage']
        
        def get_sum_func(category):
            return df_eggnog.loc[curr_rare_genes].Description.str.lower().str.contains(category.lower()).sum()
        
        results = list(map(get_sum_func, curr_list))
        
        for j, category in enumerate(curr_list):
            categories_sim[category][i] = results[j]
        
        # TODO: make category V also include all CRISPR + bacteriophage-resistance genes
        # TODO: remove double-counting (if possible) between COG and functional categories
        
        # Get COG category sums
        def get_sum_cog(category):
            return df_eggnog.loc[curr_rare_genes].COG_category.str.upper().str.contains(category.upper()).sum()
        
        results = list(map(get_sum_cog, cog_list))
        
        for j, category in enumerate(cog_list):
            categories_sim[category][i] = results[j]
    
    # Return dict of bootstrapped values for each rare-genome category
    return categories_sim