# Mock community analysis using the MOSHPIT suite

In [None]:
import glob
import json
import sys
import os
import numpy as np
import pandas as pd
import shutil
import subprocess
import uuid
import urllib.request
import tempfile

import qiime2 as q2
from qiime2.plugins import (
    rescript, feature_table, assembly, 
    moshpit, sourmash, taxa as taxa_plugin
)

from utils._utils import *

In [None]:
# # proxy support, in case needed
# urllib.request.install_opener(
#     urllib.request.build_opener(
#         urllib.request.ProxyHandler(
#             {'http' : os.environ.get('http_proxy'), 
#              'https': os.environ.get('https_proxy')}
#         )
#     )
# )

Variables and constants used throughout this notebook:

In [None]:
# global run parameters shared by the cells below
THREADS = 24
SEED = 100
TOTAL_READS = 20_000_000
READ_LEN = 150

# sample name -> simulation parameters; every sample uses the same read
# count and differs only in its abundance profile
SAMPLES = {
    sample_name: {"profile": profile, "num_reads": TOTAL_READS}
    for sample_name, profile in (
        ("uni20", "uniform"),
        ("exp20", "exponential"),
        ("log20", "lognormal"),
    )
}

Directories used throughout this notebook - all directories will be created automatically, if required:

In [None]:
# root directory for everything this notebook reads or writes
data_dir = "./data"
# destination of the per-genome FASTA files fetched below
genomes_dir = os.path.join(data_dir, "individual")
# output directory handed to the read simulation step
simulated_reads_dir = os.path.join(data_dir, "reads")

# create the working directories up front (no-op when they already exist)
for d in [genomes_dir, simulated_reads_dir]:
    os.makedirs(d, exist_ok=True)

# QIIME 2 cache used by the cells below to store and reload intermediate results
cache = q2.Cache(os.path.join(data_dir, "cache"))

## Fetch reference sequences
Start by defining a list of species which will be used to construct the mock community. Every taxon will get a UUID - this is required later in the pipeline when we generate MAGs.

In [None]:
# abbreviation -> {"taxon": query passed to the genome fetch, "uuid": random ID}
# a fresh UUID is drawn for every entry each time this cell runs
taxa = {
    abbrev: {"taxon": query, "uuid": uuid.uuid4()}
    for abbrev, query in (
        ("paer", "Pseudomonas aeruginosa"),
        # numeric ID for Escherichia coli str. K-12 substr. MG1655
        # (presumably an NCBI taxonomy ID - confirm)
        ("ecol_k12", "511145"),
        ("ecol_o157", "Escherichia coli O157:H7 str. Sakai DNA"),
        ("sent", "Salmonella enterica"),
        ("saur", "Staphylococcus aureus"),
        ("lmon", "Listeria monocytogenes"),
        ("bsub", "Bacillus subtilis"),
        ("mtub", "Mycobacterium tuberculosis"),
    )
}

Use the [RESCRIPt](https://github.com/bokulich-lab/RESCRIPt) plugin to fetch the sequences defined above. The `get-ncbi-genomes` method fetches genome sequences, corresponding taxonomies, protein and gene annotations - we will only keep the genome sequences and the taxonomies.

In [None]:
# Fetch one genome (plus taxonomy) per entry of `taxa` through RESCRIPt and
# store each genome as <uuid>.fasta in `genomes_dir`. The whole step is skipped
# when a "ref_taxonomy" key file already exists in the cache directory.
genomes_all, taxonomies_all, to_remove = [], [], []
if not os.path.isfile(os.path.join(data_dir, "cache", "keys", "ref_taxonomy")):
    # accession ID (taken from the first FASTA header) -> entry of `taxa`
    # NOTE: `ids` only exists when this branch runs; the merge cell reads it
    # in its KeyError branch
    ids = {}
    for abbrev, inner_dict in taxa.items():
        taxon = inner_dict["taxon"]
        _id = inner_dict["uuid"]
        
        # NOTE(review): `directory` is assigned but not used in this cell
        directory = f"{genomes_dir}/{abbrev}"
        
        print(f"Fetching genome for {taxon}...")
        # four results are returned; only the first (genome) and the last
        # (taxonomy) are kept
        (genome, _, _, taxonomy) = rescript.methods.get_ncbi_genomes(
            taxon=taxon,
            assembly_source="refseq",
            assembly_levels=["complete_genome"],
            only_genomic=True,
            only_reference=True
        )
    
        # the .qza is only an intermediate - it is queued for removal below
        genome_fp = os.path.join(genomes_dir, f"{abbrev}.qza")
        genome.save(genome_fp)
        with tempfile.TemporaryDirectory() as tmp:
            genome.export_data(tmp)
            
            # keep the exported FASTA under the taxon's UUID
            src = os.path.join(tmp, "dna-sequences.fasta")
            dst = os.path.join(genomes_dir, f"{_id}.fasta")
            shutil.move(src, dst)
    
            # first header line: drop the leading ">" and keep the first token
            # NOTE(review): only the first record is inspected - confirm this
            # is sufficient for genomes exported with several records
            with open(dst, "r") as f:
                accession_id = f.readline().split()[0][1:]
            
            ids[accession_id] = inner_dict
            
        genomes_all.append(genome)
        
        taxonomy_fp = os.path.join(genomes_dir, f"{abbrev}_taxonomy.qza")
        taxonomy.save(taxonomy_fp)
        taxonomies_all.append(taxonomy)
    
        # remember the intermediate artifacts so the clean-up cell can delete them
        to_remove.extend([genome_fp, taxonomy_fp])

Merge all the sequence files into a single artifact (repeat for the taxonomies) and save (we may need those later).

In [None]:
# Build the reference taxonomy (indexed by the per-taxon UUIDs) unless it is
# already available in the cache.
try:
    ref_taxonomy = cache.load("ref_taxonomy")
except KeyError:
    # merge the per-genome artifacts and keep copies next to the other data
    merged_seqs, = feature_table.methods.merge_seqs(data=genomes_all)
    merged_seqs.save(os.path.join(data_dir, "genomes.qza"))

    merged_taxonomies, = feature_table.methods.merge_taxa(data=taxonomies_all)
    merged_taxonomies.save(os.path.join(data_dir, "taxonomy.qza"))

    # update taxonomy with the new ids: accession ID -> UUID of the taxon
    # (`ids` is filled by the genome fetch cell)
    merged_taxonomies_ser = merged_taxonomies.view(pd.Series)
    merged_taxonomies_ser.index = merged_taxonomies_ser.index.map(
        {x: y["uuid"] for x, y in ids.items()}
    )

    ref_taxonomy = q2.Artifact.import_data(
        "FeatureData[Taxonomy]", merged_taxonomies_ser
    )
    cache.save(ref_taxonomy, "ref_taxonomy")

Clean up the `individual` directory.

In [None]:
# delete the intermediate per-genome artifacts; files that are already gone
# are ignored so that re-running this cell does not fail
for f in to_remove:
    try:
        os.remove(f)
    except FileNotFoundError:
        pass

Import the individual genomes into the `FeatureData[MAG]` artifact - we will use those later as our reference MAGs.

In [None]:
# import the reference FASTA files as a FeatureData[MAG] artifact unless it is
# already cached
try:
    ref_genomes = cache.load("ref_genomes")
except KeyError:
    with tempfile.TemporaryDirectory() as tmp:
        # stage copies in a scratch directory so only FASTA files are imported
        # NOTE(review): every *.fasta in genomes_dir is picked up, including
        # any left over from earlier runs - confirm the directory is clean
        for f in glob.glob(os.path.join(genomes_dir, "*.fasta")):
            shutil.copy(
                f, os.path.join(tmp, os.path.basename(f))
            )
            
        ref_genomes = q2.Artifact.import_data("FeatureData[MAG]", tmp)
        cache.save(ref_genomes, "ref_genomes")

## Simulate reads
Use the reference genomes to simulate a sequencing experiment according to the `SAMPLES` dictionary defined at the top of the notebook. We first generate the abundance profiles, as per our spec, and then use the `mason_simulator` from the [SeqAn](https://github.com/seqan/seqan) library to generate reads from the references according to the abundance profile. Finally, we import all the samples and abundances into a QIIME 2 artifact.

In [None]:
# check if reads are already available
# check if reads are already available
# `reads` stays None when nothing is cached; the following cells use that as
# the signal to run the simulation and import steps
try:
    reads = cache.load("reads")
except KeyError:
    reads = None

In [None]:
# simulate every sample described in SAMPLES (only when no cached reads exist)
# and collect what simulate_reads returns for each of them
if reads is None:
    abundances_all = []
    for sample_name, sample_details in SAMPLES.items():
        # simulate_reads comes from the star-import of utils._utils
        df = simulate_reads(
            genomes_dir=genomes_dir, 
            total_reads=sample_details["num_reads"], 
            abundance_profile=sample_details["profile"], 
            sample_name=sample_name, 
            simulated_reads_dir=simulated_reads_dir,
            threads=THREADS, read_len=READ_LEN, seed=SEED,
        )
        abundances_all.append(df)

Clean up the indices generated during read simulation.

In [None]:
# delete the *.fasta.fai index files found next to the reference genomes
for f in glob.glob(os.path.join(genomes_dir, "*.fasta.fai")):
    os.remove(f)

Construct the concatenated abundance table from all the samples and import into QIIME 2 artifact.

In [None]:
# fresh simulation: join the per-sample tables column-wise and cache them as a
# relative-frequency feature table; otherwise reuse the cached table
if reads is None:
    abundances_all = pd.concat(abundances_all, axis=1)
    abundances_artifact = q2.Artifact.import_data(
        "FeatureTable[RelativeFrequency]", abundances_all
    )
    cache.save(abundances_artifact, "abundances")
else:
    # assumes "abundances" was cached together with "reads" by an earlier run
    abundances_artifact = cache.load("abundances")

In [None]:
# show the abundance table as a DataFrame
abundances_artifact.view(pd.DataFrame)

Finally, import the reads into a QIIME 2 artifact.

In [None]:
# import the freshly simulated reads (Casava 1.8 single-lane-per-sample
# directory layout) as paired-end sequences and cache them
if reads is None:
    reads = q2.Artifact.import_data(
        "SampleData[PairedEndSequencesWithQuality]",
        simulated_reads_dir,
        "CasavaOneEightSingleLanePerSampleDirFmt",
    )
    cache.save(reads, "reads")

## Metagenome assembly
In this section we use the simulated reads to reconstruct the genomes which they originated from. We use actions from [q2-assembly](https://github.com/bokulich-lab/q2-assembly.git) and [q2-moshpit](https://github.com/bokulich-lab/q2-moshpit.git) plugins to assemble contigs, bin them into MAGs, filter the high quality MAGs and dereplicate them.

### Contig assembly
We begin by using MEGAHIT as our assembler of choice:

In [None]:
# assemble the reads into contigs with MEGAHIT unless contigs are cached
try:
    contigs = cache.load("contigs")
except KeyError:
    contigs, = assembly.pipelines.assemble_megahit(
        seqs=reads,
        presets="meta-sensitive",
        num_cpu_threads=THREADS,
    )
    cache.save(contigs, "contigs")

Next, we index the contigs and map the reads using Bowtie 2.

In [None]:
# build the index of the assembled contigs unless it is cached
try:
    contigs_index = cache.load("contigs_index")
except KeyError:
    contigs_index, = assembly.pipelines.index_contigs(
        contigs=contigs,
        threads=THREADS,
        seed=SEED,
    )
    cache.save(contigs_index, "contigs_index")

In [None]:
# map the simulated reads back to the contig index unless the maps are cached
try:
    reads_to_contigs = cache.load("reads_to_contigs")
except KeyError:
    reads_to_contigs, = assembly.pipelines.map_reads(
        index=contigs_index,
        reads=reads,
        threads=THREADS,
        seed=SEED,
    )
    cache.save(reads_to_contigs, "reads_to_contigs")

### Binning
We use the alignment maps to generate MAGs with MetaBat 2:

In [None]:
# bin the contigs into MAGs with MetaBAT 2; the cached results are reused only
# when all three outputs are present
try:
    mags = cache.load("mags")
    contig_map = cache.load("contig_map")
    contigs_unbinned = cache.load("contigs_unbinned")
except KeyError:
    # only a missing cache key triggers the (expensive) binning - any other
    # error, or an interrupt, is allowed to surface
    (mags, contig_map, contigs_unbinned) = moshpit.methods.bin_contigs_metabat(
        contigs=contigs,
        alignment_maps=reads_to_contigs,
        num_threads=THREADS,
        seed=SEED,
    )
    cache.save(mags, "mags")
    cache.save(contig_map, "contig_map")
    cache.save(contigs_unbinned, "contigs_unbinned")

To evaluate the quality of resulting MAGs we can use BUSCO. We start by fetching the prokaryotic BUSCO database, which we then use to estimate BUSCO counts in the recovered MAGs.

In [None]:
# fetch the BUSCO database (prok=True) unless it is cached
try:
    busco_db = cache.load("busco_db")
except KeyError:
    busco_db,  = moshpit.methods.fetch_busco_db(prok=True,)
    cache.save(busco_db, "busco_db")

In [None]:
# evaluate MAG quality with BUSCO; reuse the cached results table and the
# saved visualization when both are available
try:
    busco_results = cache.load("busco_results")
    busco_viz = q2.Visualization.load(os.path.join(data_dir, "mags.qzv"))
except Exception:
    # either the cache key or the .qzv file is missing/unreadable - recompute.
    # `Exception` (rather than a bare except) keeps interrupts from silently
    # starting the BUSCO run
    (busco_results, busco_viz) = moshpit.pipelines.evaluate_busco(
        bins=mags,
        busco_db=busco_db,
        lineage_dataset="bacteria_odb10",
        cpu=THREADS
    )
    cache.save(busco_results, "busco_results")
    busco_viz.save(os.path.join(data_dir, "mags.qzv"))

In [None]:
# load the saved BUSCO visualization so the notebook can render it
q2.Visualization.load(os.path.join(data_dir, "mags.qzv"))

### MAG quality filtering
Before we continue with the analysis, we filter the MAGs based on their quality - we only want to retain those which were labeled by BUSCO as more than 90% complete.

In [None]:
# keep only the MAGs whose BUSCO "complete" value is above 90, unless the
# filtered set is cached
try:
    mags_filtered = cache.load("mags_filtered")
except KeyError:
    mags_filtered, = moshpit.methods.filter_mags(
        mags=mags,
        # the BUSCO results table doubles as the filtering metadata
        metadata=busco_results.view(q2.Metadata),
        where="complete>90",
        on="mag",
    )
    cache.save(mags_filtered, "mags_filtered")

### MAG dereplication
In order to generate a dereplicated set of MAGs we will need a distance matrix. We can obtain it using the Sourmash tool: we first generate MinHash signatures of our MAGs and compare them to one another - this results in a distance matrix which we then input to the dereplication action.

In [None]:
# compute sourmash signatures of the filtered MAGs unless they are cached
try:
    mags_hashes = cache.load("mags_hashes")
except KeyError:
    mags_hashes, = sourmash.methods.compute(
        sequence_file=mags_filtered,
        # must match the ksize used by the compare step
        ksizes=51,
        scaled=10000
    )
    cache.save(mags_hashes, "mags_hashes")

In [None]:
# compare the signatures to one another to obtain the matrix used for
# dereplication, unless it is cached
try:
    mags_dist = cache.load("mags_dist")
except KeyError:
    mags_dist, = sourmash.methods.compare(
        min_hash_signature=mags_hashes,
        # same k-mer size as used when computing the signatures
        ksize=51
    )
    cache.save(mags_dist, "mags_dist")

In [None]:
# dereplicate the filtered MAGs using the sourmash matrix; both outputs must
# be cached for the cached copies to be reused
try:
    mags_derep = cache.load("mags_derep")
    mags_pa = cache.load("mags_pa")
except KeyError:
    (mags_derep, mags_pa) = moshpit.methods.dereplicate_mags(
        mags=mags_filtered,
        distance_matrix=mags_dist,
        threshold=0.98
    )
    cache.save(mags_derep, "mags_derep")
    cache.save(mags_pa, "mags_pa")

### MAG abundance estimation
We try to recover abundances of each MAG by mapping reads to the dereplicated MAGs and using the RPKM/TPM metrics as a proxy for abundance. We begin by indexing the MAGs, followed by read mapping and, finally, abundance estimation. 

In [None]:
# index the dereplicated MAGs for read mapping unless the index is cached
try:
    mags_derep_index = cache.load("mags_derep_index")
except KeyError:
    mags_derep_index, = assembly.methods.index_derep_mags(
        mags=mags_derep,
        threads=THREADS,
        seed=SEED,
    )
    cache.save(mags_derep_index, "mags_derep_index")

In [None]:
# map the simulated reads to the dereplicated MAGs unless the maps are cached
try:
    reads_to_derep_mags = cache.load("reads_to_derep_mags")
except KeyError:
    reads_to_derep_mags, = assembly.pipelines.map_reads(
        index=mags_derep_index,
        reads=reads,
        threads=THREADS,
        seed=SEED,
    )
    cache.save(reads_to_derep_mags, "reads_to_derep_mags")

In [None]:
# get the length of every dereplicated MAG (input of the abundance estimation)
# unless it is cached
try:
    mags_derep_length = cache.load("mags_derep_length")
except KeyError:
    mags_derep_length, = moshpit.methods.get_feature_lengths(
        features=mags_derep,
    )
    cache.save(mags_derep_length, "mags_derep_length")

In [None]:
# estimate MAG abundances (RPKM and TPM) from the read-to-MAG alignments;
# cached tables are reused only when both are present
try:
    mags_rpkm = cache.load("mags_rpkm")
    mags_tpm = cache.load("mags_tpm")
except KeyError:
    mags_rpkm, = moshpit.methods.estimate_mag_abundance(
        mag_lengths=mags_derep_length,
        maps=reads_to_derep_mags,
        threads=THREADS,
        metric="rpkm",
        min_mapq=42
    )
    mags_tpm, = moshpit.methods.estimate_mag_abundance(
        mag_lengths=mags_derep_length,
        maps=reads_to_derep_mags,
        threads=THREADS,
        metric="tpm",
        min_mapq=42
    )
    # store each table under its own key - the keys used to be swapped, which
    # made a cached re-run load TPM as `mags_rpkm` and vice versa.
    # NOTE: caches written before this fix still hold the swapped tables and
    # should have both keys removed so they are recomputed
    cache.save(mags_rpkm, "mags_rpkm")
    cache.save(mags_tpm, "mags_tpm")

In [None]:
# show the RPKM table, transposed
mags_rpkm.view(pd.DataFrame).T

In [None]:
# show the TPM table, transposed
mags_tpm.view(pd.DataFrame).T

## Taxonomic classification
The MOSHPIT pipeline supports a variety of ways to perform taxonomic classification. Here, we will use Kraken 2 as our classifier of choice and classify both the reads and the recovered dereplicated MAGs.

### Databases
Fetch databases required for the classification actions below.

In [None]:
# obtain the Kraken 2 and Bracken databases for the "pluspf8" collection
# unless both are cached
try:
    kraken_db = cache.load("kraken_db")
    bracken_db = cache.load("bracken_db")
except KeyError:
    (kraken_db, bracken_db) = moshpit.methods.build_kraken_db(
        collection="pluspf8"
    )
    cache.save(kraken_db, "kraken_db")
    cache.save(bracken_db, "bracken_db")

### Classification of reads
We use Kraken 2 to classify reads against the PlusPF-8 (`pluspf8`) database, followed by Bracken's abundance re-estimation.

In [None]:
# classify the simulated reads with Kraken 2 unless reports and hits are cached
try:
    kraken_reports_reads = cache.load("kraken_reports_reads")
    kraken_hits_reads = cache.load("kraken_hits_reads")
except KeyError:
    (kraken_reports_reads, kraken_hits_reads) = moshpit.pipelines.classify_kraken2(
        seqs=reads,
        kraken2_db=kraken_db,
        # note: twice the thread count used elsewhere in the notebook
        threads=2*THREADS,
        memory_mapping=True
    )
    cache.save(kraken_reports_reads, "kraken_reports_reads")
    cache.save(kraken_hits_reads, "kraken_hits_reads")

In [None]:
# re-estimate abundances from the read-level Kraken 2 reports with Bracken;
# all three outputs must be cached for the cached copies to be reused
try:
    bracken_reports = cache.load("bracken_reports")
    bracken_taxonomy = cache.load("bracken_taxonomy")
    bracken_ft = cache.load("bracken_ft")
except KeyError:
    (bracken_reports, bracken_taxonomy, bracken_ft) = moshpit.methods.estimate_bracken(
        kraken_reports=kraken_reports_reads,
        bracken_db=bracken_db,
        # same read length as used for the simulation
        read_len=READ_LEN,
        level="S"
    )
    cache.save(bracken_reports, "bracken_reports")
    cache.save(bracken_taxonomy, "bracken_taxonomy")
    cache.save(bracken_ft, "bracken_ft")

In [None]:
# taxonomic barplot of the Bracken table; reuse the saved visualization when
# it can be loaded
try:
    barplot_reads = q2.Visualization.load(os.path.join(data_dir, "reads-barplot.qzv"))
except Exception:
    # file missing or unreadable - regenerate it; `Exception` (rather than a
    # bare except) lets interrupts propagate
    barplot_reads, = taxa_plugin.visualizers.barplot(
        table=bracken_ft,
        taxonomy=bracken_taxonomy
    )
    barplot_reads.save(os.path.join(data_dir, "reads-barplot.qzv"))

### Classification of dereplicated MAGs
We use the same database to classify the dereplicated MAGs.

In [None]:
# classify the dereplicated MAGs with the same Kraken 2 database unless
# reports and hits are cached
try:
    kraken_reports_derep = cache.load("kraken_reports_derep")
    kraken_hits_derep = cache.load("kraken_hits_derep")
except KeyError:
    (kraken_reports_derep, kraken_hits_derep) = moshpit.pipelines.classify_kraken2(
        seqs=mags_derep,
        kraken2_db=kraken_db,
        threads=2*THREADS,
        memory_mapping=True
    )
    cache.save(kraken_reports_derep, "kraken_reports_derep")
    cache.save(kraken_hits_derep, "kraken_hits_derep")

In [None]:
# convert the Kraken 2 reports/hits of the MAGs into a per-MAG taxonomy unless
# it is cached
try:
    mags_derep_taxonomy = cache.load("mags_derep_taxonomy")
except KeyError:
    mags_derep_taxonomy, = moshpit.methods.kraken2_to_mag_features(
        reports=kraken_reports_derep,
        hits=kraken_hits_derep,
        coverage_threshold=10
    )
    cache.save(mags_derep_taxonomy, "mags_derep_taxonomy")

In [None]:
# print every index label of the taxonomy series followed by its value
for i, item in mags_derep_taxonomy.view(pd.Series).items():
    print(f"{i}\n{item}\n")

In [None]:
# taxonomic barplot of the MAG RPKM table; reuse the saved visualization when
# it can be loaded
try:
    barplot_derep = q2.Visualization.load(os.path.join(data_dir, "mags-rpkm-barplot.qzv"))
except Exception:
    # file missing or unreadable - regenerate it; `Exception` (rather than a
    # bare except) lets interrupts propagate
    barplot_derep, = taxa_plugin.visualizers.barplot(
        table=mags_rpkm,
        taxonomy=mags_derep_taxonomy
    )
    barplot_derep.save(os.path.join(data_dir, "mags-rpkm-barplot.qzv"))

In [None]:
# load the saved MAG barplot so the notebook can render it
q2.Visualization.load(os.path.join(data_dir, "mags-rpkm-barplot.qzv"))

## Abundance comparison
We can use the abundance estimation action to calculate the abundance of the reference genomes by mapping the simulated reads to the reference we obtained from NCBI.

In [None]:
# index the reference genomes (same action as used for the dereplicated MAGs)
# unless the index is cached
try:
    ref_genomes_index = cache.load("ref_genomes_index")
except KeyError:
    ref_genomes_index, = assembly.methods.index_derep_mags(
        mags=ref_genomes,
        threads=THREADS,
        seed=SEED,
    )
    cache.save(ref_genomes_index, "ref_genomes_index")

In [None]:
# map the simulated reads to the reference genomes unless the maps are cached
try:
    reads_to_ref_genomes = cache.load("reads_to_ref_genomes")
except KeyError:
    reads_to_ref_genomes, = assembly.pipelines.map_reads(
        index=ref_genomes_index,
        reads=reads,
        threads=THREADS,
        seed=SEED,
    )
    cache.save(reads_to_ref_genomes, "reads_to_ref_genomes")

In [None]:
# get the length of every reference genome unless it is cached
try:
    ref_genomes_length = cache.load("ref_genomes_length")
except KeyError:
    ref_genomes_length, = moshpit.methods.get_feature_lengths(
        features=ref_genomes,
    )
    cache.save(ref_genomes_length, "ref_genomes_length")

In [None]:
# estimate RPKM abundances of the reference genomes with the same settings as
# used for the dereplicated MAGs, unless the table is cached
try:
    ref_genomes_rpkm = cache.load("ref_genomes_rpkm")
except KeyError:
    ref_genomes_rpkm, = moshpit.methods.estimate_mag_abundance(
        mag_lengths=ref_genomes_length,
        maps=reads_to_ref_genomes,
        threads=THREADS,
        metric="rpkm",
        min_mapq=42
    )
    cache.save(ref_genomes_rpkm, "ref_genomes_rpkm")

In [None]:
# taxonomic barplot of the reference-genome RPKM table; reuse the saved
# visualization when it can be loaded
try:
    barplot_ref_genomes = q2.Visualization.load(
        os.path.join(data_dir, "ref-genomes-rpkm-barplot.qzv")
    )
except Exception:
    # file missing or unreadable - regenerate it; `Exception` (rather than a
    # bare except) lets interrupts propagate
    barplot_ref_genomes, = taxa_plugin.visualizers.barplot(
        table=ref_genomes_rpkm,
        taxonomy=ref_taxonomy
    )
    barplot_ref_genomes.save(os.path.join(data_dir, "ref-genomes-rpkm-barplot.qzv"))

In [None]:
# load the saved reference-genome barplot so the notebook can render it
q2.Visualization.load(os.path.join(data_dir, "ref-genomes-rpkm-barplot.qzv"))

## Functional annotation
In this section we try to identify genes in the dereplicated MAGs and annotate them using EggNOG. We start by fetching the complete Diamond database and use it to identify ortholog candidates in our set of genomes.

In [None]:
# fetch the Diamond database unless it is cached
try:
    diamond_db = cache.load("diamond_db")
except KeyError:
    diamond_db,  = moshpit.methods.fetch_diamond_db()
    cache.save(diamond_db, "diamond_db")

In [None]:
# search the dereplicated MAGs against the Diamond database; the cached
# results are reused only when both outputs are present
try:
    eggnog_hits = cache.load("eggnog_hits")
    eggnog_ftf = cache.load("eggnog_ftf")
except KeyError:
    # only a missing cache key triggers the (expensive) search - any other
    # error, or an interrupt, is allowed to surface
    eggnog_hits, eggnog_ftf = moshpit.pipelines.eggnog_diamond_search(
        sequences=mags_derep,
        diamond_db=diamond_db,
        num_cpus=THREADS,
        db_in_memory=True
    )
    cache.save(eggnog_hits, "eggnog_hits")
    cache.save(eggnog_ftf, "eggnog_ftf")

We fetch the EggNOG annotation database and perform the annotation of orthologs identified by Diamond in the previous step.

In [None]:
# fetch the EggNOG annotation database unless it is cached
try:
    eggnog_db = cache.load("eggnog_db")
except KeyError:
    eggnog_db,  = moshpit.methods.fetch_eggnog_db()
    cache.save(eggnog_db, "eggnog_db")

In [None]:
# annotate the Diamond hits using the EggNOG database unless the annotations
# are cached
try:
    ortholog_annotations = cache.load("ortholog_annotations")
except KeyError:
    ortholog_annotations, = moshpit.pipelines.eggnog_annotate(
        eggnog_hits=eggnog_hits,
        eggnog_db=eggnog_db,
        num_cpus=THREADS,
        db_in_memory=True
    )
    cache.save(ortholog_annotations, "ortholog_annotations")

The previous step generated annotation tables for each MAG which we can now convert into feature tables. We pick CAZymes (`caz`) as our annotation of choice and run the `extract_annotations` action to extract those and expand them into a feature table containing CAZyme counts (MAGs x CAZymes).

In [None]:
# extract the "caz" annotation from the ortholog annotations into a table
# unless it is cached
try:
    cazymes = cache.load("cazymes")
except KeyError:
    cazymes, = moshpit.methods.extract_annotations(
        ortholog_annotations=ortholog_annotations,
        annotation="caz",
        max_evalue=0.0001
    )
    cache.save(cazymes, "cazymes")

Finally, to obtain the count of each CAZyme per sample (sample x CAZyme) we can calculate the dot product of the MAG abundance table with the table from the previous step:

In [None]:
# multiply the MAG TPM table with the annotation table unless the product is
# cached
try:
    cazymes_ft = cache.load("cazymes_ft")
except KeyError:
    cazymes_ft, = moshpit.pipelines.multiply_tables(
        table1=mags_tpm,
        table2=cazymes
    )
    cache.save(cazymes_ft, "cazymes_ft")

In [None]:
# barplot of the per-sample annotation table (no taxonomy is passed); reuse
# the saved visualization when it can be loaded
try:
    barplot_cazymes = q2.Visualization.load(
        os.path.join(data_dir, "cazymes-barplot.qzv")
    )
except Exception:
    # file missing or unreadable - regenerate it; `Exception` (rather than a
    # bare except) lets interrupts propagate
    barplot_cazymes, = taxa_plugin.visualizers.barplot(
        table=cazymes_ft,
    )
    barplot_cazymes.save(os.path.join(data_dir, "cazymes-barplot.qzv"))

In [None]:
# load the saved annotation barplot so the notebook can render it
q2.Visualization.load(os.path.join(data_dir, "cazymes-barplot.qzv"))