# Sampling genomes for phylogenetic inference

## Summary

Select ~10k genomes out of all 86,200 reference genomes, such that they represent the largest possible biodiversity, as measured by _k_-mer signature, plus multiple other criteria concerning genome quality and marker gene count, etc.

## Preparation

### Dependencies

In [1]:
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
from skbio.stats.distance import DistanceMatrix

In [3]:
%matplotlib inline

In [4]:
sns.set()

### Input files

Genome metadata

In [None]:
meta_fp = 'metadata.ext.tsv.xz'

In [None]:
df = pd.read_csv(meta_fp, sep='\t', index_col=0, low_memory=False)
df.shape[0]

Genome distance matrix (calculated using MinHash)

In [5]:
dm_fp = 'minhash.dm.bz2'

In [None]:
%time dm = DistanceMatrix.read(dm_fp)

### Helpers

Core algorithm: **prototype selection**

In [None]:
def prototype_selection_destructive_maxdist(dm, num_prototypes, seedset=None):
    """Select prototypes by greedily discarding the most redundant elements.

    Every element starts with a score equal to the sum of its distances to
    all other elements. The element with the lowest score is discarded, its
    distances are subtracted from the scores of the others, and the process
    repeats until only `num_prototypes` elements remain.

    Parameters
    ----------
    dm : distance matrix
        Object exposing `ids`, `data` (square array of distances) and
        `index(id)`, e.g. a scikit-bio DistanceMatrix.
    num_prototypes : int
        Number of elements to keep. If it is not smaller than the number of
        elements in `dm`, nothing is discarded.
    seedset : iterable of ids, optional
        Elements to favour. Their initial score is raised to twice the
        largest score, which delays their removal; nothing in the loop
        strictly prevents a seed from being discarded later.

    Returns
    -------
    list
        Ids of the retained elements, in the order of `dm.ids`.
    """
    num_remain = len(dm.ids)
    # Float copy so that discarded elements can be flagged with inf even
    # when the distance data has an integer dtype.
    curr_dists = dm.data.sum(axis=1).astype(float)
    if seedset is not None:
        seed_score = curr_dists.max() * 2
        for elem in seedset:
            curr_dists[dm.index(elem)] = seed_score
    discarded = None
    while num_remain > num_prototypes:
        if discarded is not None:
            # Remaining elements no longer count their distance to the
            # element discarded in the previous round.
            curr_dists -= dm.data[discarded]
        discarded = curr_dists.argmin()
        curr_dists[discarded] = np.inf
        num_remain -= 1
    return [dm.ids[idx]
            for idx, dist in enumerate(curr_dists)
            if dist != np.inf]

In [None]:
def distance_sum(elements, dm):
    """Return the total pairwise distance within a subset of genomes.

    `dm` is restricted to `elements` through its `filter` method; the lower
    triangle (diagonal included) of the resulting matrix is then summed, so
    each unordered pair is counted once.
    """
    submatrix = dm.filter(elements).data
    lower_triangle = np.tril(submatrix)
    return lower_triangle.sum()

Generate a histogram with an upper fence.

In [None]:
def hist_w_max(data, step, xmax):
    """Draw a histogram whose last bin collects all values above a cutoff.

    Bins of width `step` start at 0. The data are clipped to the bin range
    before plotting, so values beyond the last edge are counted in the final
    bin; the last tick is labelled 'Inf' instead of its numeric edge.
    """
    bins = np.arange(0, xmax + step * 2, step)
    tick_positions = step * np.arange(len(bins))
    # Clip so that out-of-range values fall into the outermost bins.
    clipped = np.clip(data, bins[0], bins[-1])
    # plt.hist returns (counts, edges, patches); only the edges are needed.
    edges = plt.hist(clipped, bins=bins)[1]
    tick_labels = [str(edge) for edge in edges][:-1]
    tick_labels.append('Inf')
    plt.xticks(tick_positions, tick_labels)

Generate a count plot

In [None]:
def count_plot(data, percent=False, **kwargs):
    """Plot bars representing categorical counts, each annotated on top.

    Parameters
    ----------
    data : sequence of categorical values
        Values to count; handed to `seaborn.countplot` as `x`.
    percent : bool, optional
        If True, annotate each bar with its share of `len(data)` as a
        fraction with two decimals; otherwise with the integer count.
    **kwargs
        Forwarded to `seaborn.countplot` (e.g. `order`).

    Returns
    -------
    The axes object returned by `seaborn.countplot`.
    """
    # Pass the values by keyword: the first positional parameter of
    # countplot is not `x` in every seaborn release.
    ax = sns.countplot(x=data, **kwargs)
    total = float(len(data))
    for patch in ax.patches:
        height = patch.get_height()
        label = '{:1.2f}'.format(height / total) if percent else str(int(height))
        ax.text(patch.get_x() + patch.get_width() / 2., height + 3,
                label, ha='center')
    ax.set_xlabel('')
    return ax

Generate a multi-panel figure to show the statistics of currently selected genomes

In [None]:
def examine_set(genomes):
    """Plot statistics of chosen genomes.

    Prints the number of genomes, then draws a one-row figure with five
    panels: the sorted values of 'markers', 'completeness', 'contamination'
    and 'score_fna' for the given genomes, followed by a bar chart of the
    percentage of taxonomic groups covered at each rank, annotated with the
    absolute number of groups.

    Relies on the notebook-level variables `df` (genome metadata), `ranks`
    (taxonomic rank column names) and `n_taxa` (number of groups per rank
    in the full data set).

    Parameters
    ----------
    genomes : collection of genome ids
        Index labels of `df` to include.
    """
    print('%d' % len(genomes))
    gs = mpl.gridspec.GridSpec(1, 5, width_ratios=[1, 1, 1, 1, 2])
    cols = ['markers', 'completeness', 'contamination', 'score_fna']
    ymaxes = [400, 100, 600, 1.0]
    dfc = df[df.index.isin(genomes)]
    for i, (col, ymax) in enumerate(zip(cols, ymaxes)):
        ax = plt.subplot(gs[i])
        ax.plot(dfc[col].sort_values().tolist(), color='C%d' % i)
        ax.set_title(col)
        ax.set_ylim([0, ymax])
    # Groups represented in the subset, as a count and as a percentage of
    # the groups present in the full data set.
    m_taxa = dfc[ranks].apply(pd.Series.nunique)
    p_taxa = m_taxa / n_taxa * 100
    ax = plt.subplot(gs[4])
    sns.barplot(x=p_taxa.index, y=p_taxa, ax=ax)
    for rank, patch in zip(ranks, ax.patches):
        height = patch.get_height()
        ax.text(patch.get_x() + patch.get_width() / 2., height + 3,
                m_taxa[rank], ha='center')
    ax.set_ylim([0, 110])
    ax.set_xlabel('')
    # One label per plotted rank (its initial letter); a fixed label list
    # would fail as soon as its length differs from the number of ranks.
    ax.set_xticklabels([rank[0] for rank in ranks])
    ax.set_title('% included')
    plt.tight_layout()

## Analysis

### Genome metadata

Plot distributions of some metadata fields

In [None]:
mpl.rcParams['figure.figsize'] = (8, 2.5)

In [None]:
hist_w_max(df['total_length'] / 1000000, 1, 10)
plt.title('Genome size (Mbp)');

In [None]:
hist_w_max(df['proteins'], 1000, 10000)
plt.title('Number of proteins per genome');

In [None]:
hist_w_max(df['protein_length'] / 1000, 250, 3000)
plt.title('Total length of proteins per genome (kaa)');

In [None]:
bins = [0, 50, 80, 90, 95, 97.5, 99, 99.5, 99.9, 100]
# The bins have unequal widths, so count per bin first and draw the counts
# as equal-width bars positioned by bin number; ticks carry the real edges.
# np.histogram does not count values outside [bins[0], bins[-1]].
counts = np.histogram(df['completeness'], bins=bins)[0]
plt.bar(range(len(counts)), counts, width=1, align='edge')
plt.xticks(np.arange(len(bins)), list(map(str, bins)))
plt.title('Completeness (%)');

In [None]:
bins = [0, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 20, 50, 100]
# The bins have unequal widths, so count per bin first and draw the counts
# as equal-width bars positioned by bin number; ticks carry the real edges.
# np.histogram does not count values outside [bins[0], bins[-1]].
counts = np.histogram(df['contamination'], bins=bins)[0]
plt.bar(range(len(counts)), counts, width=1, align='edge')
plt.xticks(np.arange(len(bins)), list(map(str, bins)))
plt.title('Contamination (%)');

In [None]:
bins = np.arange(0, 425, 25)
plt.hist(df['markers'], bins=bins)
plt.title('Number of PhyloPhlAn marker genes');

In [None]:
plt.hist(df['score_fna'])
plt.title('RepoPhlAn score of genome sequence (fna)');

In [None]:
plt.hist(df['n50'], log=True)
plt.title('N50 of scaffolds');

In [None]:
order=['Complete Genome', 'Chromosome', 'Scaffold', 'Contig']
count_plot(df['assembly_level'], order=order).set_title('Assembly level');

In [None]:
count_plot(df['refseq_category'], log=True).set_title('RefSeq category');

In [None]:
count_plot(df['release_type']).set_title('Release type');

In [None]:
ranks = ['phylum', 'class', 'order', 'family', 'genus', 'species']
df[ranks].describe()

In [None]:
mpl.rcParams['figure.figsize'] = (9, 3)
# Number of distinct taxonomic groups at each rank over all genomes.
n_taxa = df[ranks].apply(pd.Series.nunique)
ax = sns.barplot(x=n_taxa.index, y=n_taxa, log=True)
# Write the count above each bar.
for bar in ax.patches:
    bar_top = bar.get_height()
    bar_center = bar.get_x() + bar.get_width() / 2.
    ax.text(bar_center, bar_top + 3, str(int(bar_top)), ha='center')
ax.set_xlabel('')
ax.set_title('Number of taxonomic groups per rank');

In [None]:
mpl.rcParams['figure.figsize'] = (12, 4)
f, axarr = plt.subplots(1, 3)
# One entry per panel:
# (x column, y column, x label, y label, cap the y axis at 200?)
panels = [
    ('completeness', 'contamination', 'Completeness (%)', 'Contamination (%)', True),
    ('markers', 'completeness', 'marker gene count', 'Completeness (%)', False),
    ('markers', 'contamination', 'marker gene count', 'Contamination (%)', True),
]
for panel, (xcol, ycol, xlabel, ylabel, cap_y) in zip(axarr, panels):
    panel.scatter(x=xcol, y=ycol, alpha=0.25, data=df)
    if cap_y:
        panel.set_ylim(ymax=200)
    panel.set_xlabel(xlabel)
    panel.set_ylabel(ylabel)
plt.suptitle('Distribution of marker gene count, contamination and completeness');

Make a one-figure summary

In [None]:
examine_set(df.index)

### Criteria for sampling

In [None]:
mpl.rcParams['figure.figsize'] = (12, 2.5)

Quality filtering

In [None]:
qualified = set(df.query('markers >= 100').index.tolist())
examine_set(qualified)

In [None]:
qualified = set(df.query('contamination <= 10').index.tolist())
examine_set(qualified)

In [None]:
qualified = set(df.query('completeness >= 80').index.tolist())
examine_set(qualified)

NCBI reference and representative genomes

In [None]:
refp = set(df[df['refseq_category'] != 'na'].index.tolist())
examine_set(refp)

Sole representatives of taxonomic groups (groups that contain exactly one genome)

In [None]:
singles = {}
reports = []
for rank in ('phylum', 'class', 'order', 'family', 'genus', 'species'):
    # Taxa observed in exactly one genome (value_counts skips missing values).
    taxon_counts = df[rank].value_counts()
    single_taxa = taxon_counts.index[taxon_counts == 1]
    # Ids of the genomes that belong to those single-member taxa.
    singles[rank] = set(df[df[rank].isin(single_taxa)].index.tolist())
    reports.append('%s: %d' % (rank, len(singles[rank])))
print('Taxonomic groups with only one representative: %s' % ', '.join(reports))

In [None]:
examine_set(singles['phylum'])

In [None]:
examine_set(singles['class'])

In [None]:
examine_set(singles['order'])

In [None]:
examine_set(singles['family'])

In [None]:
examine_set(singles['genus'])

In [None]:
examine_set(singles['species'])

No defined taxonomy above species

In [None]:
# Genomes that have a species assignment but no classification at any of the
# ranks from phylum to genus.
df_notax = df[df[['phylum', 'class', 'order', 'family', 'genus']].isnull().all(axis=1)
              & df['species'].notnull()]
print('%d genomes belonging to %d species'
      % (len(df_notax), df_notax['species'].nunique()))

In [None]:
examine_set(df_notax.index)

In [None]:
single_notax = df_notax.groupby('species').filter(lambda x: len(x) == 1).index
examine_set(single_notax)

Prototype selection

In [None]:
k = 11000  # number of prototypes to keep

In [None]:
%time prototypes = prototype_selection_destructive_maxdist(dm, k)
print('Sum of distances: %d.' % distance_sum(prototypes, dm))

In [None]:
examine_set(prototypes)

### Formal sampling

Procedures:
 1. Exclude genomes with contamination > 10% or marker gene count < 100.
 2. Include NCBI reference and representative genomes.
 3. Include every genome that is the sole representative of its taxon at any rank from phylum to genus.
 4. Include every genome that is the sole representative of a species without defined lineage.
 5. Run prototype selection based on the MinHash distance matrix, with already included genomes as seeds, to obtain a total of 11000 genomes.
 6. For each taxon from phylum to genus, and each species without defined lineage, that is still unrepresented, add the genome with the highest marker gene count.

In [None]:
chosen = set()

Step 1: Exclude genomes with contamination > 10% or marker gene count < 100

In [None]:
dfp = df.query('contamination <= 10 and markers >= 100')
print('Genomes passed quality filtering: %d' % dfp.shape[0])

In [None]:
examine_set(dfp.index)

Step 2: Include NCBI reference and representative genomes

In [None]:
refp = dfp[dfp['refseq_category'] != 'na'].index
chosen.update(refp)
print('Added NCBI reference and representative genomes: %d' % len(refp))

In [None]:
examine_set(chosen)

Step 3: Include genomes that are the sole representative of their taxon at any rank from phylum to genus

In [None]:
reports = []
for rank in ('phylum', 'class', 'order', 'family', 'genus'):
    # Taxa that occur exactly once among the quality-filtered genomes
    # (genomes with no value at this rank are not counted).
    group_sizes = dfp[rank].value_counts()
    lone_taxa = group_sizes.index[group_sizes == 1]
    single_taxa = dfp.index[dfp[rank].isin(lone_taxa)]
    # Only report genomes that were not already selected.
    toadd = set(single_taxa) - chosen
    chosen |= toadd
    reports.append('%s: %s' % (rank, len(toadd)))
print('Added taxonomic groups: %s' % ', '.join(reports))

In [None]:
examine_set(chosen)

Step 4: Include genomes that are the sole representative of a species without defined lineage

In [None]:
# Quality-filtered genomes with a species assignment but no classification
# at any rank from phylum to genus.
dfp_notax = dfp[dfp[['phylum', 'class', 'order', 'family', 'genus']].isnull().all(axis=1)
                & dfp['species'].notnull()]
# Of those, the genomes whose species has no other member.
single_notax = dfp_notax.groupby('species').filter(lambda grp: len(grp) == 1).index
toadd = set(single_notax) - chosen
chosen |= toadd
print('Added species without lineage: %d' % len(toadd))

In [None]:
examine_set(chosen)

Step 5: Run prototype selection based on the MinHash distance matrix, with already included genomes as seeds, to obtain a total of 11000 genomes

In [None]:
%time dmp = dm.filter(dfp.index).copy()
len(dmp.ids)

In [None]:
%time prototypes = prototype_selection_destructive_maxdist(dmp, 11000, chosen)
print('Sum of distances: %d.' % distance_sum(prototypes, dmp))

In [None]:
toadd = set(prototypes) - chosen
chosen.update(toadd)
print('Added prototypes: %d' % len(toadd))

In [None]:
examine_set(chosen)

Step 6: For each taxon from phylum to genus, and each species without defined lineage, that is still unrepresented, add the genome with the highest marker gene count

In [None]:
reports = []
dfc = dfp[dfp.index.isin(chosen)]
for rank in ('phylum', 'class', 'order', 'family', 'genus'):
    # Taxa at this rank that already have at least one chosen genome.
    chosen_taxa = set(dfc[rank].dropna())
    # Candidate genomes: classified at this rank, taxon not yet represented.
    missing = dfp[dfp[rank].notnull() & ~dfp[rank].isin(chosen_taxa)]
    # One genome per unrepresented taxon: the one with the most marker genes.
    if missing.empty:
        toadd = set()
    else:
        toadd = set(missing.groupby(rank)['markers'].idxmax())
    chosen.update(toadd)
    reports.append('%s: %d' % (rank, len(toadd)))
    # Refresh so the next rank sees the genomes that were just added.
    dfc = dfp[dfp.index.isin(chosen)]
print('Added taxonomic groups: %s' % ', '.join(reports))

In [None]:
# Lineage-less species that already have a chosen genome.
dfc_notax = dfc[dfc.index.isin(dfp_notax.index)]
represented = set(dfc_notax['species'])
# Only species with several genomes are considered here; genomes listed in
# single_notax are excluded.
multi_notax = dfp_notax[~dfp_notax.index.isin(single_notax)]
# Compare values directly rather than through a query string, which breaks
# on species names that contain a double quote.
missing = multi_notax[~multi_notax['species'].isin(represented)]
# One genome per unrepresented species: the one with the most marker genes.
if missing.empty:
    toadd = set()
else:
    toadd = set(missing.groupby('species')['markers'].idxmax())
chosen.update(toadd)
dfc = dfp[dfp.index.isin(chosen)]
print('Added species without lineage: %d' % len(toadd))

In [None]:
examine_set(chosen)

In [None]:
dfc['markers'].describe()

In [None]:
mpl.rcParams['figure.figsize'] = (12, 4)
f, axarr = plt.subplots(1, 3)
# One entry per panel: (x column, y column, x label, y label)
panels = [
    ('completeness', 'contamination', 'Completeness (%)', 'Contamination (%)'),
    ('markers', 'completeness', 'marker gene count', 'Completeness (%)'),
    ('markers', 'contamination', 'marker gene count', 'Contamination (%)'),
]
for panel, (xcol, ycol, xlabel, ylabel) in zip(axarr, panels):
    panel.scatter(x=xcol, y=ycol, alpha=0.25, data=dfc)
    panel.set_xlabel(xlabel)
    panel.set_ylabel(ylabel)
plt.suptitle('Distribution of marker gene count, contamination and completeness');

Export results

In [None]:
# Write the chosen genome ids, sorted, one per line.
with open('sampled.txt', 'w') as f:
    f.writelines('%s\n' % g for g in sorted(chosen))

In [None]:
df['chosen'] = df.index.isin(chosen)

In [None]:
df.to_csv('summary_sampled.tsv', sep='\t')