## Setup

__Note:__ This analysis requires a lot of temporary storage space (~200 GB) during the data fetch phase. Afterwards, about 50 GB are needed to store all the output artifacts.

In [None]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import qiime2 as q2
import seaborn as sns
import skbio

from matplotlib import cm
from qiime2.plugins import (
    fondue, sourmash, diversity, emperor, demux, 
    sample_classifier, cutadapt
)

# Directory holding all downloaded data and generated artifacts.
data_loc = 'u2-genome-results'
# exist_ok avoids the separate existence check; it still raises if the path
# exists but is not a directory.
os.makedirs(data_loc, exist_ok=True)

# Contact e-mail passed to q2-fondue's get_metadata / get_sequences calls below;
# replace with your own address.
email = 'your@email.com'
# Parallelism used by the fetching, trimming (cutadapt cores) and classification steps.
n_jobs = 16

# Nextstrain metadata: local path, download URL, and submission-date cutoff
# (records submitted after this date are dropped).
nextstrain_metadata_path = os.path.join(data_loc, 'metadata_nextstrain.tsv')
nextstrain_meta_url = 'https://data.nextstrain.org/files/ncov/open/metadata.tsv.gz'
nextstrain_last_submit_date = '2022-01-31'

# Target number of genomes per virus variant and seed used for all sampling.
genomes_per_variant = 250
random_seed = 11

# Cached outputs of the metadata-processing steps.
sra_metadata_path = os.path.join(data_loc, 'metadata_sra.tsv')
metadata_merged_path = os.path.join(data_loc, 'metadata_merged.tsv')

In [None]:
def sample_variants(metadata_df, n, grouping_col='Nextstrain_clade', random_state=1):
    """Draw a random, stratified sample from all available virus variants.

    Every group (variant) is sampled independently with the same
    ``random_state``, so the rows drawn for one variant do not depend on the
    other variants. Groups with fewer than ``n`` rows are returned in full
    (shuffled) instead of raising.

    Args:
        metadata_df (pd.DataFrame): Metadata of all samples.
        n (int): Sample size per virus variant (upper bound for groups with
            fewer than ``n`` rows).
        grouping_col (str): Name of the column containing variant name. Rows
            with a missing value in this column are not sampled.
        random_state (int): Random seed to be used when sampling.

    Returns:
        pd.DataFrame: DataFrame containing subsampled metadata. It is indexed
        by ``sra_accession`` if that column exists (otherwise the original row
        labels are kept), the index is named ``id``, and the grouping column
        is retained.
    """
    # Iterate the groups explicitly instead of using groupby().apply(): apply
    # operating on the grouping column is deprecated and future pandas versions
    # drop that column from the result, but downstream code needs it.
    samples = [
        group.sample(n=min(n, len(group)), random_state=random_state)
        for _, group in metadata_df.groupby(grouping_col)
    ]
    # pd.concat([]) raises, so fall back to an empty frame with the same columns.
    sampled = pd.concat(samples) if samples else metadata_df.iloc[:0].copy()
    if 'sra_accession' in sampled.columns:
        sampled = sampled.set_index('sra_accession', drop=True)
    sampled.index.name = 'id'
    return sampled


def color_variants(x, cmap='plasma'):
    """
    Return a color from provided color map based on virus variant.

    The colormap is resampled to 8 discrete colors. 'Alpha' gets the first
    color, 'Delta' the second, and any other variant the third.

    Args:
        x (str): Variant name.
        cmap (str): Matplotlib's color map name (must resolve to a listed
            colormap exposing ``.colors``, e.g. 'plasma').

    Returns:
        One entry of the resampled colormap's ``colors``.
    """
    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9; pyplot.get_cmap takes the same (name, lut) arguments.
    colors = plt.get_cmap(cmap, 8).colors
    position = {'Alpha': 0, 'Delta': 1}.get(x, 2)
    return colors[position]

## Process Nextstrain's metadata

We are interested in taking a sample of SARS-CoV-2 genomes from the full Nextstrain list. We will only consider genomes available in the SRA repository for a few virus variants. Moreover, we will only work with single-end sequences to simplify the analysis. We begin by fetching the original Nextstrain metadata:

In [None]:
%%bash -s "$nextstrain_metadata_path" "$data_loc" "$nextstrain_meta_url"
# Download and decompress the Nextstrain metadata unless it is already present.
# Positional arguments (passed from the Python variables via -s):
#   $1 - expected path of the decompressed metadata TSV
#   $2 - output directory
#   $3 - URL of the gzipped metadata file

if test -f "$1"; then
    echo "$1 exists and will not be re-downloaded."
else
    wget -nv -O "$2/metadata.tsv.gz" "$3";
    gzip -f -d "$2/metadata.tsv.gz";
    # NOTE(review): this target must match $1 for the existence check above to skip re-downloads
    mv "$2/metadata.tsv" "$2/metadata_nextstrain.tsv"
fi

In [None]:
# load the full (tab-separated) Nextstrain metadata table and show its dimensions
metadata_ns = pd.read_table(nextstrain_metadata_path)
metadata_ns.shape

In [None]:
# preview the first five records of the unfiltered Nextstrain metadata
metadata_ns.head(5)

In [None]:
# parse submission dates, then keep only records submitted on or before the cutoff date
metadata_ns = metadata_ns.assign(
    date_submitted=pd.to_datetime(metadata_ns['date_submitted'])
)
metadata_ns = metadata_ns.loc[
    metadata_ns['date_submitted'].le(nextstrain_last_submit_date)
]
metadata_ns.shape

In [None]:
# Convert date_submitted back to string (to conform with QIIME 2's Metadata format).
# `assign` returns a new frame, so we do not write into the filtered slice produced
# by the date filter (which would raise pandas' SettingWithCopyWarning).
metadata_ns = metadata_ns.assign(date_submitted=metadata_ns['date_submitted'].astype(str))

In [None]:
# check count of samples per (ungrouped) Nextstrain clade
metadata_ns['Nextstrain_clade'].value_counts()

Only keep samples with SRA accession numbers.

In [None]:
# Keep records with a single SRA accession starting with SRR or ERR; values
# containing a comma list several accessions and are excluded. `na=False` makes
# missing values count as non-matching explicitly, and `~` replaces `== False`.
metadata_ns = metadata_ns[
    metadata_ns['sra_accession'].notna()
    & metadata_ns['sra_accession'].str.startswith(('SRR', 'ERR'), na=False)
    & ~metadata_ns['sra_accession'].str.contains(',', regex=False, na=False)
]
metadata_ns.shape

Only keep samples with good `QC_missing_data`.

In [None]:
# retain only records whose missing-data QC status equals 'good'
metadata_ns = metadata_ns.loc[metadata_ns['QC_missing_data'].eq('good')]
metadata_ns.shape

Filter out and group together only the desired virus variants.

In [None]:
# variants of interest; a clade is kept when its name mentions any of them
variants = ['Alpha', 'Delta', 'Omicron']

# na=False drops records with a missing clade in the same step as the name match
metadata_ns_vars = metadata_ns.loc[
    metadata_ns['Nextstrain_clade'].str.contains('|'.join(variants), na=False)
]
metadata_ns_vars.shape

In [None]:
# Collapse the detailed Nextstrain clades into one label per variant: any clade whose
# name contains e.g. 'Delta' becomes 'Delta' (if several variants match, the first in
# `variants` wins). Work on an explicit copy because this frame comes from boolean
# filtering, and adding a column to such a slice triggers SettingWithCopyWarning.
metadata_ns_vars = metadata_ns_vars.copy()
metadata_ns_vars['Nextstrain_clade_grouped'] = metadata_ns_vars['Nextstrain_clade']
for variant in variants:
    metadata_ns_vars['Nextstrain_clade_grouped'] = (
        metadata_ns_vars['Nextstrain_clade_grouped']
        .str.replace(rf'.*{variant}.*', variant, regex=True)
    )

clade_counts = metadata_ns_vars['Nextstrain_clade_grouped'].value_counts()

In [None]:
# number of records per grouped variant
clade_counts

Sample _n_ sequences per virus variant. At this stage we deliberately oversample (150 times the final per-variant count): the properties we need to filter on (e.g., single-end reads) are only known once the SRA metadata has been fetched for these records.

In [None]:
# Deliberately oversample (150x the final per-variant target): only records that the
# SRA metadata later shows to be single-end NextSeq 550 runs are kept, so a large pool
# is needed to draw the final sample from.
n = 150 * genomes_per_variant
metadata_ns_vars_smp = sample_variants(
    metadata_ns_vars, n=n, grouping_col='Nextstrain_clade_grouped', random_state=random_seed
)

In [None]:
# Remove columns with mixed types (we will not need those). errors='ignore' keeps the
# cell working if a column is absent from the downloaded metadata version.
metadata_ns_vars_smp.drop(columns=['clock_deviation'], errors='ignore', inplace=True)

In [None]:
# clade composition of the oversampled pool
metadata_ns_vars_smp['Nextstrain_clade'].value_counts()

## Fetch sample metadata using q2-fondue
Fetch metadata for the pre-selected sequences using q2-fondue's `get-metadata` action. We will then use this metadata to keep only samples with single-end reads and merge them with the original Nextstrain metadata. Finally, we will subsample those to get the final list of genomes to fetch, stratified per virus variant.

In [None]:
# We fetch metadata in several batches, due to the large ID count.
metadata_batch_size = 4000
ids = metadata_ns_vars_smp.index.to_list()
# Range over the list actually being sliced, using one named batch size.
ids_chunked = [
    ids[start:start + metadata_batch_size]
    for start in range(0, len(ids), metadata_batch_size)
]

In [None]:
# Fetch SRA metadata for all pre-selected IDs batch by batch, caching each batch (and
# its failed IDs) as a .qza file so an interrupted run can resume. Then merge all batches
# and cache the result as a TSV. If that TSV already exists, it is read directly and
# nothing is fetched.
# NOTE(review): batch files are keyed only by batch number, not by their contents; if the
# sampled IDs change (e.g. a different random_seed), delete stale sra_meta_batch*.qza files.
all_meta = []
if not os.path.isfile(sra_metadata_path):
    for i, _ids in enumerate(ids_chunked):
        print(f'-----Fetching metadata - batch {i + 1} out of {len(ids_chunked)}...-----')
        current_batch_loc = os.path.join(data_loc, f'sra_meta_batch{i}.qza')
        # get_metadata receives the IDs as an NCBIAccessionIDs artifact built from a Series named 'ID'
        _ids = pd.Series(_ids, name='ID')
        
        if not os.path.isfile(current_batch_loc):
            # failed_ids is only saved for later inspection; this notebook does not use it further
            sra_meta, failed_ids, = fondue.methods.get_metadata(
                accession_ids=q2.Artifact.import_data('NCBIAccessionIDs', _ids),
                email=email,
                n_jobs=n_jobs,
                log_level='WARNING'
            )
            sra_meta.save(current_batch_loc)
            failed_ids.save(os.path.join(data_loc, f'sra_failed_ids_batch{i}.qza'))
        else:
            print(f'Reading current SRA meta batch from file {current_batch_loc}...')
            sra_meta = q2.Artifact.load(current_batch_loc)
            
        all_meta.append(sra_meta)
        del sra_meta
    
    # merge metadata from all the batches
    sra_meta, = fondue.methods.merge_metadata(
        metadata=all_meta
    )
    sra_meta_df = sra_meta.view(pd.DataFrame)
    sra_meta_df.to_csv(sra_metadata_path, sep='\t')
    
    # clean up
    del all_meta
else:
    print(f'Metadata artifact exists and will be read from {sra_metadata_path}.')
    # first column of the cached TSV is the index written by to_csv above
    sra_meta_df = pd.read_csv(sra_metadata_path, sep='\t', index_col=0)

## Merge SRA and Nextstrain metadata
Merge SRA metadata with Nextstrain metadata and re-sample, keeping only __single-end short__ reads (sequenced on a NextSeq 550).

In [None]:
# Inner join (merge's default) on the indexes: only IDs present in both the sampled
# Nextstrain table (indexed by sra_accession) and the SRA metadata are kept.
# NOTE(review): assumes the SRA metadata index holds the same run accessions — confirm.
sra_meta_smp_df = metadata_ns_vars_smp.merge(sra_meta_df, left_index=True, right_index=True)
sra_meta_smp_df.shape

In [None]:
# boolean mask of single-end runs sequenced on a NextSeq 550 instrument;
# records without an instrument value count as non-matching
selection = (
    sra_meta_smp_df['Instrument'].str.contains('NextSeq 550', na=False)
    & sra_meta_smp_df['Library Layout'].eq('SINGLE')
)

sra_meta_smp_df = sra_meta_smp_df.loc[selection]

In [None]:
# per-variant counts after filtering; count() reports non-null values per column,
# only the first two columns are displayed
sra_meta_smp_df_gr = sra_meta_smp_df.groupby(['Nextstrain_clade_grouped']).count()
sra_meta_smp_df_gr.iloc[:,:2]

Find the largest possible sample size.

In [None]:
# Largest possible per-variant sample size: the smallest variant group among the
# filtered records, capped at `genomes_per_variant`. Group sizes come from value_counts
# (rows per variant) rather than count() of an arbitrary first column, which would
# under-count if that column contained missing values.
n = sra_meta_smp_df['Nextstrain_clade_grouped'].value_counts().min()
if pd.isna(n):
    # an empty selection would otherwise silently fall back to genomes_per_variant
    raise ValueError('No records left after filtering for single-end NextSeq 550 runs.')
n = min(int(n), genomes_per_variant)
print(f'Taking a sample of {n} genomes per variant.')

In [None]:
# Rebuild the merged table from scratch (so re-running this cell is safe), keep the rows
# matching `selection`, and draw the final stratified sample of n genomes per variant.
sra_meta_smp_df = sample_variants(
    metadata_ns_vars_smp.merge(sra_meta_df, left_index=True, right_index=True)[selection],
    n=n,
    grouping_col='Nextstrain_clade_grouped',
    random_state=random_seed,
)
# store the 'Public' column as strings
sra_meta_smp_df['Public'] = sra_meta_smp_df['Public'].astype(str)
sra_meta_smp_df.shape

In [None]:
# check count of samples per variant in the final sample
sra_meta_smp_df['Nextstrain_clade_grouped'].value_counts()

Save merged & sampled metadata to file.

In [None]:
# Write the final metadata only once: an existing file is NOT overwritten, so delete it
# manually if the sampling parameters above are changed.
if not os.path.isfile(metadata_merged_path):
    sra_meta_smp_df.to_csv(metadata_merged_path, sep='\t')
    print('Saved sample metadata to', metadata_merged_path)

## Fetch SARS-CoV-2 genomes using q2-fondue
We can use IDs from our final metadata table to fetch all the corresponding sequencing files from the SRA using `q2-fondue`'s `get-sequences` action.

In [None]:
# cache location of the fetched single-end reads artifact
single_reads_out = os.path.join(data_loc, 'sars-single.qza')

In [None]:
if os.path.isfile(single_reads_out):
    # reuse the reads cached by a previous run
    print(f'Single-reads artifact exists and will be read from {single_reads_out}.')
    single_reads = q2.Artifact.load(single_reads_out)
else:
    # Fetch reads for every ID of the final metadata table. Only the first output of
    # get_sequences is kept; the other two are discarded
    # (presumably paired-end reads and failed IDs — verify against the q2-fondue docs).
    _ids = pd.Series(sra_meta_smp_df.index.to_list(), name='ID')
    single_reads, _, _ = fondue.methods.get_sequences(
        accession_ids=q2.Artifact.import_data('NCBIAccessionIDs', _ids),
        email=email,
        n_jobs=n_jobs,
    )
    single_reads.save(single_reads_out)

## Quality control of the sequences
Before proceeding to the next step, we can assess the quality of the retrieved dataset using the `summarize` action from the `q2-demux` plugin.

In [None]:
qc_viz_out = os.path.join(data_loc, 'qc-viz.qzv')
if os.path.isfile(qc_viz_out):
    # reuse the cached visualization from a previous run
    print(f'Quality control artifact exists and will be read from {qc_viz_out}.')
    qc_viz = q2.Visualization.load(qc_viz_out)
else:
    # summarize the raw (untrimmed) reads and cache the visualization
    qc_viz, = demux.visualizers.summarize(data=single_reads)
    qc_viz.save(qc_viz_out)

In [None]:
# render the quality summary of the raw reads
qc_viz

## Data clean-up: sequence trimming
As can be seen in the visualization above, the data is already of good quality. We will just perform one additional cleaning step with `cutadapt`: reads shorter than 35 bp after trimming are discarded, and the maximum error rate allowed when matching adapter sequences is set to 0.01.

In [None]:
trimmed_out = os.path.join(data_loc, 'sars-single-trimmed.qza')
if os.path.isfile(trimmed_out):
    # reuse the trimmed reads cached by a previous run
    print(f'Trimmed reads artifact exists and will be read from {trimmed_out}.')
    single_reads_trimmed = q2.Artifact.load(trimmed_out)
else:
    # cutadapt: discard reads shorter than 35 bp after trimming;
    # error_rate is cutadapt's maximum allowed error rate
    single_reads_trimmed, = cutadapt.methods.trim_single(
        demultiplexed_sequences=single_reads,
        error_rate=0.01,
        minimum_length=35,
        cores=n_jobs,
    )
    single_reads_trimmed.save(trimmed_out)

In [None]:
trimmed_viz_out = os.path.join(data_loc, 'qc-viz-trimmed.qzv')
if os.path.isfile(trimmed_viz_out):
    # reuse the cached visualization from a previous run
    print(f'Trimmed reads visualization exists and will be read from {trimmed_viz_out}.')
    qc_viz_trimmed = q2.Visualization.load(trimmed_viz_out)
else:
    # summarize the trimmed reads for comparison with the raw-read summary
    qc_viz_trimmed, = demux.visualizers.summarize(data=single_reads_trimmed)
    qc_viz_trimmed.save(trimmed_viz_out)
qc_viz_trimmed

## Calculate and compare MinHash signatures for every genome
Having checked the data quality, we will proceed to calculating the MinHash signatures of every genome using `q2-sourmash`. First, we calculate the hashes from the short reads using the `compute` action. Subsequently, we generate a distance matrix comparing hashes pairwise (using the `compare` action).

In [None]:
genome_hash_out = os.path.join(data_loc, 'genome-hash-trimmed.qza')
if os.path.isfile(genome_hash_out):
    # reuse the signatures cached by a previous run
    print(f'Genome hashes artifact exists and will be read from {genome_hash_out}.')
    genome_hash = q2.Artifact.load(genome_hash_out)
else:
    # MinHash signatures of the trimmed reads with k-mer size 31;
    # scaled=10 is sourmash's hash subsampling factor
    genome_hash, = sourmash.methods.compute(
        sequence_file=single_reads_trimmed,
        ksizes=31,
        scaled=10,
    )
    genome_hash.save(genome_hash_out)

In [None]:
hash_compare_out = os.path.join(data_loc, 'hash-compare-trimmed.qza')
if os.path.isfile(hash_compare_out):
    # reuse the distance matrix cached by a previous run
    print(f'Distance matrix artifact exists and will be read from {hash_compare_out}.')
    hash_compare = q2.Artifact.load(hash_compare_out)
else:
    # pairwise comparison of all signatures; ksize matches the k used when computing them (31)
    hash_compare, = sourmash.methods.compare(
        min_hash_signature=genome_hash,
        ksize=31,
    )
    hash_compare.save(hash_compare_out)

## Perform dimensionality reduction of the genome MinHash distance matrix
Finally, a 2D t-SNE embedding is computed from the obtained distance matrix (`tsne` method from the `q2-diversity` plugin) and visualized as an EMPeror plot (`plot` action from the `q2-emperor` plugin).

In [None]:
# t-SNE embedding of the MinHash distance matrix. Unlike the other steps, this result is
# not cached to disk, so it is recomputed every time the cell runs.
# NOTE(review): if this t-SNE run is not seeded, the embedding used by the custom plot
# below may differ from the one in a cached EMPeror plot — confirm.
genome_tsne, = diversity.methods.tsne(
    distance_matrix=hash_compare,
    learning_rate=125,
    perplexity=18
)

In [None]:
# metadata columns passed to the EMPeror plot
cols = ['Nextstrain_clade_grouped', 'Nextstrain_clade', 'Instrument']

In [None]:
emperor_plot_out = os.path.join(data_loc, 'emperor-plot-trimmed.qzv')
if os.path.isfile(emperor_plot_out):
    # reuse the plot cached by a previous run
    print(f'Emperor plot artifact exists and will be read from {emperor_plot_out}.')
    emperor_plot = q2.Visualization.load(emperor_plot_out)
else:
    # interactive plot of the t-SNE coordinates, annotated with the selected metadata columns
    emperor_plot, = emperor.visualizers.plot(
        pcoa=genome_tsne,
        metadata=q2.Metadata(sra_meta_smp_df[cols]),
    )
    emperor_plot.save(emperor_plot_out)

In [None]:
# render the EMPeror visualization
emperor_plot

We can also use the results above to generate our own plots using any of the Python plotting libraries - see below.

In [None]:
# View the t-SNE artifact as scikit-bio OrdinationResults; `.samples` holds the per-sample
# coordinates, whose first two columns are plotted below.
tsne_table = genome_tsne.view(skbio.OrdinationResults)
tsne_df = tsne_table.samples

In [None]:
# switch to inline plotting (IPython magic: figures are rendered in the notebook output)
%matplotlib inline

In [None]:
# create a 2D plot of Dim1 vs Dim2, colored by grouped virus variant

sns.set(rc={'figure.figsize': (8, 8), 'font.family': ['Arial']}, style='white')
with sns.plotting_context("notebook", font_scale=1.2):
    # single-axes figure (same as plt.figure() + add_subplot(111))
    fig, ax = plt.subplots()

    # Labels are set before plotting; seaborn fills in axis labels only when they are
    # empty (verify for the installed seaborn version), so these are kept.
    ax.set_xlabel('Axis 1')
    ax.set_ylabel('Axis 2')

    # x, y and hue are pandas Series; presumably seaborn aligns them on their sample-ID
    # index, so colors match coordinates even if the row order differs — confirm.
    sns.scatterplot(
        x=tsne_df.iloc[:, 0],
        y=tsne_df.iloc[:, 1],
        s=70,
        hue=sra_meta_smp_df['Nextstrain_clade_grouped'],
        ax=ax,
        alpha=0.75,
    )

    # t-SNE axes carry no meaningful scale, so hide the ticks
    ax.set_xticks([])
    ax.set_yticks([])

    plt.tight_layout()

In [None]:
# Save the t-SNE figure as EPS.
# NOTE(review): Matplotlib's PostScript backend does not support transparency, so the
# alpha=0.75 markers are rendered opaque in this file; consider also saving PDF/SVG.
fig.savefig(os.path.join(data_loc, 'sars_cov_2_tsne.eps'))

We can see from the plots above that when the `Nextstrain_clade_grouped` column is used to color the data points, genomes group into distinct clusters corresponding to their variant assignment. Most clearly separated is the Omicron variant, which forms a single cluster, while Alpha and Delta each form their own (multiple smaller) clusters.

## Classify samples using hash signatures
We can also more quantitatively test whether MinHash genome signatures generated by _sourmash_ are predictive of SARS-CoV-2 genome variant. To do that we can use the `classify_samples_from_dist` method from the `q2-sample-classifier` plugin.

In [None]:
predictions_out = os.path.join(data_loc, 'predictions.qza')
accuracy_out = os.path.join(data_loc, 'accuracy.qzv')
# Re-run the classifier unless BOTH outputs are cached. Checking only the predictions
# file would crash on load if a previous run saved the predictions but stopped before
# saving the accuracy visualization.
if not (os.path.isfile(predictions_out) and os.path.isfile(accuracy_out)):
    # classify the grouped variant from the MinHash distance matrix with k=3 (number of
    # nearest neighbours) and 10-fold cross-validation (cv=10)
    predictions, accuracy, = sample_classifier.pipelines.classify_samples_from_dist(
        distance_matrix=hash_compare,
        metadata=q2.CategoricalMetadataColumn(sra_meta_smp_df['Nextstrain_clade_grouped']),
        k=3,
        cv=10,
        random_state=random_seed,
        n_jobs=n_jobs,
    )
    predictions.save(predictions_out)
    accuracy.save(accuracy_out)
else:
    print(f'Classification artifacts exist and will be read from {predictions_out} and {accuracy_out}.')
    predictions = q2.Artifact.load(predictions_out)
    accuracy = q2.Visualization.load(accuracy_out)

In [None]:
# render the classification accuracy visualization
accuracy

Virus variants can be successfully classified using only genome hashes with a high degree of accuracy (~93%). Importantly, a more detailed analysis would be required to investigate the potential influence of variables that were not controlled for in this analysis (e.g., geographic location of the samples). The `test_size = 0` warning can be disregarded here: classification was evaluated using 10-fold cross-validation.