# Setup

In [None]:
# Base imports
import os
import pickle

# Compute imports
import numpy as np
import pandas as pd
import scipy
from tqdm.notebook import tqdm, trange

# Plotting imports
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
from plotly import express as px

# Figure export settings: TrueType (Type 42) fonts in PDF/PS, text kept as
# text (not paths) in SVG, and Arial as the sans-serif face.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'font.family': 'sans-serif',
})
sns.set_style('ticks')
# Pure-black text/labels/ticks; applied after sns.set_style so the style
# cannot override them (order matters here).
matplotlib.rcParams.update({
    rc_key: '#000000'
    for rc_key in ('text.color', 'axes.labelcolor', 'xtick.color', 'ytick.color')
})
# ML import
#from sklearn.decomposition import NMF
#from sklearn.metrics import mean_squared_error, median_absolute_error

In [None]:
# Gene x strain table (genes as rows, genome IDs as columns). Missing entries
# become 0, then the sparse frame is densified and stored compactly as int8.
df_genes = (
    pd.read_pickle('../../data/processed/cd-hit-results/sim80/Ebacter_strain_by_gene.pickle.gz')
    .fillna(0)
    .sparse.to_dense()
    .astype('int8')
)

display(df_genes.shape, df_genes.head())

In [None]:
# Genome metadata, indexed by the CSV's first column. dtype='object' reads
# every column as strings, so ID-like values are not coerced to numbers.
metadata = pd.read_csv(
    '../../data/metadata/mash_scrubbed_species_metadata.csv',
    index_col=0,
    dtype='object',
)

display(metadata.shape, metadata.head())

In [None]:
# Keep only metadata rows whose genome_status is 'Complete'
metadata_complete = metadata.loc[metadata['genome_status'].eq('Complete')]

# Restrict the gene x strain matrix to those genomes' columns, then drop
# genes that are absent from every complete genome
df_genes_complete = df_genes[metadata_complete['genome_id']]
inCompleteseqs = df_genes_complete.sum(axis=1).gt(0)
df_genes_complete = df_genes_complete.loc[inCompleteseqs]

df_genes_complete.shape

## Useful functions

In [None]:
# Function for separating Core-, Accessory-, & Rare- genomes

def find_pangenome_segments(df_genes, threshold=0.1, ax=None):
    '''
    Computes the gene frequency thresholds at which a gene can be categorized as
    core, accessory, or unique. Specifically, models the gene frequency distribution
    as the sum of two power laws (one flipped), and fits the CDF to a five-parameter
    function derived from those power laws. Also identifies the inflection point and
    the core and unique extremes relative to the inflection point and threshold.

          PMF(x;c1,c2,a1,a2) ~ c1 * x^-a1 + c2 * (n-x)^-a2
        CDF(x;c1,c2,a1,a2,k) ~ c1/(1-a1) * x^(1-a1) - c2/(1-a2) * (n-x)^(1-a2) + k

    Where x = frequency, n = maximum observed frequency + 1, other variables are
    parameters.

    Pangenome segments example at 10%:
    - N = maximum observed gene frequency (equal to the number of strains when
      at least one gene is present in every strain), R = computed inflection point
    - Core: Observed in >= R + (1 - 0.1) * (N-R) strains
    - Unique: Observed in <= 0.1 * R strains
    - Accessory: Everything in between

    Parameters
    ----------
    df_genes : pd.DataFrame, np.ndarray or scipy.sparse matrix
        Binary gene x strain table (genes as rows). NaN entries of a DataFrame
        are treated as absent.
    threshold : float
        Proximity to each frequency extreme compared to inflection point
        that determines if a gene is core, unique, or accessory (default 0.1)
    ax : plt.axes
        If provided, plots pangenome frequency CDF with segments (default None)

    Returns
    -------
    segments : tuple
        2-tuple with (min core limit, max unique limit), not rounded.
    popt : np.ndarray
        The 5 fitted CDF parameters (c1,c2,a1,a2,k). Note that c1, c2, and k
        are scaled relative to the number of genes in the lowest observed
        frequency bin.
    r_squared : float
        R^2 between fit and observed cumulative gene frequency distribution
    mae : float
        Mean absolute error between fit and observed cumulative distribution
    ax : plt.axes
        Returned as a fifth element only when ax is provided

    Raises
    ------
    ValueError
        If no gene is observed in any strain.
    '''
    # Explicit submodule import: a bare `import scipy` may not expose
    # scipy.optimize (older SciPy releases do not lazy-load subpackages)
    import scipy.optimize

    # Computing gene frequencies and frequency counts
    if isinstance(df_genes, pd.DataFrame):
        gene_freq = df_genes.fillna(0).sum(axis=1)
    else:  # array provided
        # np.asarray(...).ravel() also flattens the (n_genes, 1) np.matrix
        # that scipy.sparse matrices return from .sum(axis=1)
        gene_freq = pd.Series(np.asarray(df_genes.sum(axis=1)).ravel())
    # Drop unobserved genes (i.e. when subsetting genomes), order bins by frequency
    freq_counts = (
        gene_freq.value_counts()
        .drop(index=0, errors='ignore')
        .sort_index()
    )
    if freq_counts.empty:
        raise ValueError('df_genes has no gene observed in any strain')
    cumulative_frequencies = np.cumsum(freq_counts.values)
    frequency_bins = np.array(freq_counts.index)

    # Fitting CDF; the model is scaled by Y[0] so c1, c2 and k are relative
    X = frequency_bins.astype(float)
    Y = cumulative_frequencies.astype(float)
    n = max(frequency_bins) + 1

    def dual_power_cdf(x, c1, c2, a1, a2, k):
        return Y[0] * (c1 * np.power(x, 1.0 - a1) / (1.0 - a1)
                       - c2 * np.power(n - x, 1.0 - a2) / (1.0 - a2) + k)

    p0 = [1.0, 1.0, 2.0, 2.0, 1.0]
    bounds = ([0.0, 0.0, 1.0, 1.0, 0.0],
              [np.inf, np.inf, np.inf, np.inf, Y[-1] / Y[0]])
    popt, _ = scipy.optimize.curve_fit(dual_power_cdf, X, Y, p0=p0,
                                       bounds=bounds, maxfev=100000)

    # Inflection point of the CDF = minimum of the fitted PMF on [1, n-1]
    def dual_power_pdf_fit(x):
        c1, c2, a1, a2 = popt[:4]
        return Y[0] * (c1 * np.power(x, -a1) + c2 * np.power(n - x, -a2))

    res = scipy.optimize.minimize_scalar(dual_power_pdf_fit, method='bounded',
                                         bounds=[1, n - 1])
    inflection_freq = res.x  # inflection point x, i.e. frequency threshold
    unique_strains_max = inflection_freq * threshold
    core_strains_min = inflection_freq + (n - 1 - inflection_freq) * (1.0 - threshold)
    segments = (core_strains_min, unique_strains_max)

    # Curve fit evaluation: R^2 and MAE.
    # R^2 is technically invalid for nonlinear fits but commonly reported;
    # MAE is more relevant for judging nonlinear models.
    Yfit = dual_power_cdf(X, *popt)  # fitted CDF (vectorized)
    SStot = np.sum(np.square(Y - Y.mean()))
    SSres = np.sum(np.square(Y - Yfit))
    r_squared = 1 - (SSres / SStot)
    mae = np.abs(Y - Yfit).mean()

    if ax is None:
        return segments, popt, r_squared, mae

    # Optionally, generating plot
    ax.plot(X, Y, label='observed')
    ax.plot(X, Yfit, label='fit', ls='--')
    ax.scatter([inflection_freq], [dual_power_cdf(inflection_freq, *popt)],
               label='inflection point', color='black', alpha=0.7)
    ax.axvline(unique_strains_max, ls='--', color='k')
    ax.axvline(core_strains_min, ls='--', color='k')
    ax.axvline(inflection_freq, ls='--', color='lightgray')

    unique_rounded = int(unique_strains_max) + 1
    core_rounded = int(core_strains_min)
    unique_text = 'Unique:\n<' + str(unique_rounded)
    core_text = 'Core:\n>' + str(core_rounded)
    r2_text = 'R^2=' + str(np.round(r_squared, 3))
    mae_text = 'MAE=' + str(np.round(mae, 2))
    ax.text(unique_strains_max + n*0.02, Y[0], unique_text, ha='left', va='bottom')
    ax.text(core_strains_min - n*0.02, Y[0], core_text, ha='right', va='bottom')
    ax.text(unique_strains_max + n*0.1, 0.95*Y[-1], r2_text, ha='left', va='bottom')
    ax.text(unique_strains_max + n*0.1, 0.95*Y[-1], mae_text, ha='left', va='top')
    ax.set_xlabel('Gene frequency')
    ax.set_ylabel('Cumulative genes')
    return segments, popt, r_squared, mae, ax

def find_pangenome_segments(df_genes, threshold=0.1, ax=None):
    '''
    Computes the gene frequency thresholds at which a gene can be categorized as
    core, accessory, or unique. Specifically, models the gene frequency distribution
    as the sum of two power laws (one flipped), and fits the CDF to a five-parameter
    function derived from those power laws. Also identifies the inflection point and
    the core and unique extremes relative to the inflection point and threshold.

          PMF(x;c1,c2,a1,a2) ~ c1 * x^-a1 + c2 * (n-x)^-a2
        CDF(x;c1,c2,a1,a2,k) ~ c1/(1-a1) * x^(1-a1) - c2/(1-a2) * (n-x)^(1-a2) + k

    Where x = frequency, n = maximum observed frequency + 1, other variables are
    parameters.

    Pangenome segments example at 10%:
    - N = maximum observed gene frequency (equal to the number of strains when
      at least one gene is present in every strain), R = computed inflection point
    - Core: Observed in >= R + (1 - 0.1) * (N-R) strains
    - Unique: Observed in <= 0.1 * R strains
    - Accessory: Everything in between

    Parameters
    ----------
    df_genes : pd.DataFrame, np.ndarray or scipy.sparse matrix
        Binary gene x strain table (genes as rows). NaN entries of a DataFrame
        are treated as absent.
    threshold : float
        Proximity to each frequency extreme compared to inflection point
        that determines if a gene is core, unique, or accessory (default 0.1)
    ax : plt.axes
        If provided, plots pangenome frequency CDF with segments (default None)

    Returns
    -------
    segments : tuple
        2-tuple with (min core limit, max unique limit), not rounded.
    popt : np.ndarray
        The 5 fitted CDF parameters (c1,c2,a1,a2,k). Note that c1, c2, and k
        are scaled relative to the number of genes in the lowest observed
        frequency bin.
    r_squared : float
        R^2 between fit and observed cumulative gene frequency distribution
    mae : float
        Mean absolute error between fit and observed cumulative distribution
    ax : plt.axes
        Returned as a fifth element only when ax is provided

    Raises
    ------
    ValueError
        If no gene is observed in any strain.
    '''
    # Explicit submodule import: a bare `import scipy` may not expose
    # scipy.optimize (older SciPy releases do not lazy-load subpackages)
    import scipy.optimize

    # Computing gene frequencies and frequency counts
    if isinstance(df_genes, pd.DataFrame):
        gene_freq = df_genes.fillna(0).sum(axis=1)
    else:  # array provided
        # np.asarray(...).ravel() also flattens the (n_genes, 1) np.matrix
        # that scipy.sparse matrices return from .sum(axis=1)
        gene_freq = pd.Series(np.asarray(df_genes.sum(axis=1)).ravel())
    # Drop unobserved genes (i.e. when subsetting genomes), order bins by frequency
    freq_counts = (
        gene_freq.value_counts()
        .drop(index=0, errors='ignore')
        .sort_index()
    )
    if freq_counts.empty:
        raise ValueError('df_genes has no gene observed in any strain')
    cumulative_frequencies = np.cumsum(freq_counts.values)
    frequency_bins = np.array(freq_counts.index)

    # Fitting CDF; the model is scaled by Y[0] so c1, c2 and k are relative
    X = frequency_bins.astype(float)
    Y = cumulative_frequencies.astype(float)
    n = max(frequency_bins) + 1

    def dual_power_cdf(x, c1, c2, a1, a2, k):
        return Y[0] * (c1 * np.power(x, 1.0 - a1) / (1.0 - a1)
                       - c2 * np.power(n - x, 1.0 - a2) / (1.0 - a2) + k)

    p0 = [1.0, 1.0, 2.0, 2.0, 1.0]
    bounds = ([0.0, 0.0, 1.0, 1.0, 0.0],
              [np.inf, np.inf, np.inf, np.inf, Y[-1] / Y[0]])
    popt, _ = scipy.optimize.curve_fit(dual_power_cdf, X, Y, p0=p0,
                                       bounds=bounds, maxfev=100000)

    # Inflection point of the CDF = minimum of the fitted PMF on [1, n-1]
    def dual_power_pdf_fit(x):
        c1, c2, a1, a2 = popt[:4]
        return Y[0] * (c1 * np.power(x, -a1) + c2 * np.power(n - x, -a2))

    res = scipy.optimize.minimize_scalar(dual_power_pdf_fit, method='bounded',
                                         bounds=[1, n - 1])
    inflection_freq = res.x  # inflection point x, i.e. frequency threshold
    unique_strains_max = inflection_freq * threshold
    core_strains_min = inflection_freq + (n - 1 - inflection_freq) * (1.0 - threshold)
    segments = (core_strains_min, unique_strains_max)

    # Curve fit evaluation: R^2 and MAE.
    # R^2 is technically invalid for nonlinear fits but commonly reported;
    # MAE is more relevant for judging nonlinear models.
    Yfit = dual_power_cdf(X, *popt)  # fitted CDF (vectorized)
    SStot = np.sum(np.square(Y - Y.mean()))
    SSres = np.sum(np.square(Y - Yfit))
    r_squared = 1 - (SSres / SStot)
    mae = np.abs(Y - Yfit).mean()

    if ax is None:
        return segments, popt, r_squared, mae

    # Optionally, generating plot
    ax.plot(X, Y, label='observed')
    ax.plot(X, Yfit, label='fit', ls='--')
    ax.scatter([inflection_freq], [dual_power_cdf(inflection_freq, *popt)],
               label='inflection point', color='black', alpha=0.7)
    ax.axvline(unique_strains_max, ls='--', color='k')
    ax.axvline(core_strains_min, ls='--', color='k')
    ax.axvline(inflection_freq, ls='--', color='lightgray')

    unique_rounded = int(unique_strains_max) + 1
    core_rounded = int(core_strains_min)
    unique_text = 'Unique:\n<' + str(unique_rounded)
    core_text = 'Core:\n>' + str(core_rounded)
    r2_text = 'R^2=' + str(np.round(r_squared, 3))
    mae_text = 'MAE=' + str(np.round(mae, 2))
    ax.text(unique_strains_max + n*0.02, Y[0], unique_text, ha='left', va='bottom')
    ax.text(core_strains_min - n*0.02, Y[0], core_text, ha='right', va='bottom')
    ax.text(unique_strains_max + n*0.1, 0.95*Y[-1], r2_text, ha='left', va='bottom')
    ax.text(unique_strains_max + n*0.1, 0.95*Y[-1], mae_text, ha='left', va='top')
    ax.set_xlabel('Gene frequency')
    ax.set_ylabel('Cumulative genes')
    return segments, popt, r_squared, mae, ax

# CAR (core, accessory, rare) genomes

## Total pangenome curve (both complete + WGS sequences)

In [None]:
# Per-gene frequency (row sums over all genomes), histogrammed with a
# log-scaled count axis
df_gene_freq = df_genes.sum(axis=1)

fig, ax = plt.subplots()
sns.histplot(df_gene_freq, binwidth=50, ax=ax)
ax.set_yscale('log')
plt.show()

In [None]:
# Fit the cumulative gene-frequency curve for all genomes and draw the
# core/unique cut-offs (10% threshold) on a fresh axis
fig, ax = plt.subplots()
segments, popt, r_squared, mae, ax = find_pangenome_segments(
    df_genes, ax=ax, threshold=0.1
)

In [None]:
# Split all genomes' genes into core / accessory / rare tables using the
# fitted thresholds: segments = (min core frequency, max unique frequency).
df_freq = df_genes.sum(axis=1)

is_core = df_freq > np.floor(segments[0])
is_rare = df_freq < np.ceil(segments[1])

df_core = df_genes[is_core]
df_rare = df_genes[is_rare]

# Accessory = neither core nor rare. A boolean mask keeps df_genes' row
# order; a list built from set differences has arbitrary order (for str
# labels it can change between sessions because of hash randomization),
# which would leak into the saved pickle.
df_acc = df_genes[~is_core & ~is_rare].copy()
acc_gene_list = df_acc.index.tolist()

display(
    df_core.shape,
    df_acc.shape,
    df_rare.shape
)

## Complete sequences only (needed for NMF)

In [None]:
def bin_sizes(x, thresh):
    """Return histogram bin edges for ``x`` in which ``thresh`` is a whole number of bins.

    The starting width is the cube-root rule
    ``(max(x) - min(x)) / len(x) ** (1/3)``. When that is smaller than
    ``thresh`` the width is shrunk so ``thresh`` is an integer multiple of it;
    otherwise (including the zero-spread case where all values are equal)
    the width is ``thresh / 2``. The lower edge lies at least one bin below
    both ``-thresh`` and ``min(x)``, and the upper edge at least one bin above
    both ``thresh`` and ``max(x)``.

    Parameters
    ----------
    x : array-like
        Non-empty values to be histogrammed.
    thresh : float
        Positive threshold (e.g. an iModulon threshold) to align bins with.

    Returns
    -------
    np.ndarray
        ``np.arange(xmin, xmax + width, width)``.

    Raises
    ------
    ValueError
        If ``x`` is empty or ``thresh`` is not positive.
    """
    x = np.asarray(x)
    if x.size == 0:
        raise ValueError('x must contain at least one value')
    if thresh <= 0:
        raise ValueError('thresh must be positive')

    opt_width = (x.max() - x.min()) / (len(x) ** (1 / 3))

    # Width calculated using optimal width and iModulon threshold. A zero
    # optimal width (no spread in x) would divide by zero below, so it takes
    # the thresh / 2 fallback as well.
    if 0 < opt_width < thresh:
        width = thresh / int(thresh / opt_width)
    else:
        width = thresh / 2

    # Use width and thresh to calculate xmin, xmax
    if x.min() < -thresh:
        multiple = np.ceil(abs(x.min() / width))
        xmin = -(multiple + 1) * width
    else:
        xmin = -(thresh + width)

    if x.max() > thresh:
        multiple = np.ceil(x.max() / width)
        xmax = (multiple + 1) * width
    else:
        xmax = thresh + width

    return np.arange(xmin, xmax + width, width)

In [None]:
# Cumulative gene-frequency fit restricted to complete genomes, exported as SVG.
# Note: this rebinds `segments`, `popt`, `r_squared` and `mae` to the
# complete-genome fit.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Enterobacter Cumulative Gene Distribution")
segments, popt, r_squared, mae, ax = find_pangenome_segments(
    df_genes_complete, threshold=0.1, ax=ax
)

# Keep SVG text as editable text rather than outlined paths
plt.rcParams['svg.fonttype'] = 'none'
fig.savefig('../images/gene_dist.svg', format='svg', dpi=300, bbox_inches='tight')


In [None]:
# Per-gene frequency across complete genomes only, histogrammed with a
# log-scaled count axis
df_gene_freq_complete = df_genes_complete.sum(axis=1)

fig, ax = plt.subplots()
sns.histplot(df_gene_freq_complete, binwidth=50, ax=ax)
ax.set_yscale('log')
plt.show()

In [None]:
# Core / accessory / rare split for complete genomes, using `segments` from
# the most recent find_pangenome_segments call:
# segments = (min core frequency, max unique frequency).
df_freq_complete = df_genes_complete.sum(axis=1)

is_core_complete = df_freq_complete > np.floor(segments[0])
is_rare_complete = df_freq_complete < np.ceil(segments[1])

df_core_complete = df_genes_complete[is_core_complete]
df_rare_complete = df_genes_complete[is_rare_complete]

# Accessory = neither core nor rare. A boolean mask keeps df_genes_complete's
# row order; a list built from set differences has arbitrary order (for str
# labels it can change between sessions because of hash randomization),
# which would leak into the saved pickle.
df_acc_complete = df_genes_complete[~is_core_complete & ~is_rare_complete].copy()
acc_gene_list_complete = df_acc_complete.index.tolist()

display(
    df_core_complete.shape,
    df_acc_complete.shape,
    df_rare_complete.shape
)

In [None]:
# Imported here: this cell runs before the later cell that also imports gcd,
# so relying on that import raises NameError on a top-down run.
from math import gcd

# Broken-y-axis histogram of gene frequencies in complete genomes: the top
# panel shows counts 50000-60000 (outlier bars), the bottom panel 0-3500
# (most of the data); both panels share the x axis.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(10,6))
fig.subplots_adjust(hspace=0.05)  # adjust space between axes

# Bin width is the GCD of the integer parts of both cut-offs, so each
# cut-off is an integer multiple of it. gcd(0, 0) == 0 is not a usable bin
# width, so fall back to 1 in that degenerate case.
bin_width = gcd(int(segments[0]), int(segments[1])) or 1

# plot the same data on both axes
sns.histplot(df_gene_freq_complete, binwidth=bin_width, ax=ax1, linewidth=1, edgecolor='black')
sns.histplot(df_gene_freq_complete, binwidth=bin_width, ax=ax2, linewidth=1, edgecolor='black')
# zoom-in / limit the view to different portions of the data
ax1.set_ylim(50000, 60000)  # outliers only
ax2.set_ylim(0, 3500)  # most of the data

# hide the spines between ax1 and ax2
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top
ax2.xaxis.tick_bottom()

# One shared y label: drop ax2's and move ax1's to the bottom edge of the
# top panel (axes coords y=0.0), i.e. between the two panels
ax2.set_ylabel('')
ax1.yaxis.set_label_coords(-0.085, 0.0)

# Slanted "break" markers where the two panels meet
d = .5  # proportion of vertical to horizontal extent of the slanted line
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
              linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

ax1.set_title("Enterobacter Gene Frequency Plot")

# Core (segments[0]) and unique (segments[1]) cut-offs on both panels
for panel in (ax1, ax2):
    for cutoff in segments:
        panel.axvline(x=cutoff, color='red', linestyle='--', linewidth=2)

ax1.set_xlim(right=501)

# Saved under its own name: the log-scale histogram cell below also writes
# '../images/gene_freq.svg' and would otherwise overwrite this figure.
plt.savefig('../images/gene_freq_broken_axis.svg', format='svg', dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# Order of magnitude of the largest gene frequency among complete genomes
# (log10 is monotonic, so log of the max equals max of the logs)
np.log10(df_gene_freq_complete.max())

In [None]:
from math import gcd

# Gene-frequency histogram for complete genomes on a single axis with a
# log10 count scale; bin width is the GCD of the integer parts of the
# core/unique cut-offs.
bin_width = gcd(int(segments[0]), int(segments[1]))

fig, ax1 = plt.subplots(1, 1, figsize=(10, 6))
sns.histplot(df_gene_freq_complete, binwidth=bin_width, ax=ax1, log_scale=(False, 10))
ax1.set_title("Enterobacter Gene Frequency Plot")

# Core cut-off first, then the unique cut-off
for cutoff in segments:
    ax1.axvline(x=cutoff, color='red', linestyle='--')

plt.savefig('../images/gene_freq.svg', format='svg', dpi=600)
plt.show()

# Save results

In [None]:
# Persist the core / accessory / rare tables: first for all genomes,
# then for the complete-genome subset.
car_outputs = [
    # Total
    (df_core, '../../data/processed/CAR_genomes/df_core.pickle'),
    (df_acc, '../../data/processed/CAR_genomes/df_acc.pickle'),
    (df_rare, '../../data/processed/CAR_genomes/df_rare.pickle'),
    # Complete
    (df_core_complete, '../../data/processed/CAR_genomes/df_core_complete.pickle'),
    (df_acc_complete, '../../data/processed/CAR_genomes/df_acc_complete.pickle'),
    (df_rare_complete, '../../data/processed/CAR_genomes/df_rare_complete.pickle'),
]
for frame, out_path in car_outputs:
    frame.to_pickle(out_path)