diff --git a/.condarc.yaml b/.condarc.yaml
deleted file mode 100644
index 90c78f428..000000000
--- a/.condarc.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-channels:
- - defaults
- - bioconda
- - conda-forge
-show_channels_urls: True
-default_threads: 6
diff --git a/.github/ISSUE_TEMPLATE/support_request.md b/.github/ISSUE_TEMPLATE/support_request.md
new file mode 100644
index 000000000..c36491bec
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/support_request.md
@@ -0,0 +1,52 @@
+---
+name: Support request
+about: Need clarification on a few scripts or the pipeline in general? Ask us!
+---
+
+
+
+### User checklist
+
+- [ ] Are you using the latest release?
+- [ ] Are you using Python 3?
+- [ ] Did you check previous issues to see if this has already been mentioned?
+- [ ] Are you using a Mac or Linux machine?
+
+#### Description
+
+
+
+#### Expected Behavior
+
+
+
+#### System Environment
+
+
+
+- Operating System:
+- RAM:
+- Disk:
+
+#### Tasks/Command(s)
+
+
+
+- [ ] Task 1
+- [ ] Task 2
+- [ ] Task 3
+- [ ] etc.
+
+Log/Error information generated by Autometa.
+
+
+```
+
+```
+
+
diff --git a/.gitignore b/.gitignore
index de5adf94f..dca11f12a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -138,13 +138,14 @@ dmypy.json
autometa/*.pyc
autometa/taxonomy/*.pyc
autometa/databases/markers/*.h3*
+autometa/databases/ncbi/*
# databases / testing
tests/data/*
!tests/data/metagenome.fna
# visualStudioCode
-.vscode/*
+.vscode/
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
diff --git a/autometa/binning/recursive_dbscan.py b/autometa/binning/recursive_dbscan.py
index 8cccd3fdf..f85825a0f 100644
--- a/autometa/binning/recursive_dbscan.py
+++ b/autometa/binning/recursive_dbscan.py
@@ -26,8 +26,11 @@
import logging
import os
+import shutil
+import tempfile
import pandas as pd
+import numpy as np
from Bio import SeqIO
from sklearn.cluster import DBSCAN
@@ -38,7 +41,7 @@
# TODO: This should be from autometa.common.kmers import Kmers
# So later we can simply/and more clearly do Kmers.load(kmers_fpath).embed(method)
-from autometa.common.exceptions import RecursiveDBSCANError
+from autometa.common.exceptions import BinningError
from autometa.taxonomy.ncbi import NCBI
pd.set_option("mode.chained_assignment", None)
@@ -47,14 +50,114 @@
logger = logging.getLogger(__name__)
+def add_metrics(df, markers_df, domain="bacteria"):
+ """Adds the completeness and purity metrics to each respective contig in df.
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ index='contig' cols=['x','y','coverage','cluster']
+ markers_df : pd.DataFrame
+ wide format, i.e. index=contig cols=[marker,marker,...]
+ domain : str, optional
+ Kingdom to determine metrics (the default is 'bacteria').
+ choices=['bacteria','archaea']
+
+ Returns
+ -------
+ 2-tuple
+ `df` with added cols=['completeness', 'purity']
+ pd.DataFrame(index=clusters, cols=['completeness', 'purity'])
+
+ Raises
+ -------
+ KeyError
+ `domain` is not "bacteria" or "archaea"
+
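+    Examples
+    --------
+    A minimal sketch with toy values (contig and marker names here are
+    hypothetical):
+
+    >>> import pandas as pd
+    >>> df = pd.DataFrame(
+    ...     {"x": [1.0, 1.1, 9.0], "y": [2.0, 2.1, 8.9], "cluster": [0, 0, 1]},
+    ...     index=["contig_1", "contig_2", "contig_3"],
+    ... )
+    >>> markers_df = pd.DataFrame(
+    ...     {"marker_A": [1, 0, 1], "marker_B": [0, 1, 1]},
+    ...     index=["contig_1", "contig_2", "contig_3"],
+    ... )
+    >>> merged_df, metrics_df = add_metrics(df, markers_df, domain="bacteria")
+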
+ """
+ domain = domain.lower()
+ marker_sets = {"bacteria": 139.0, "archaea": 162.0}
+ if domain not in marker_sets:
+ raise KeyError(f"{domain} is not bacteria or archaea!")
+ expected_number = marker_sets[domain]
+ clusters = dict(list(df.groupby("cluster")))
+ metrics = {"purity": {}, "completeness": {}}
+ for cluster, dff in clusters.items():
+ contigs = dff.index.tolist()
+ summed_markers = markers_df[markers_df.index.isin(contigs)].sum()
+ is_present = summed_markers >= 1
+ is_single_copy = summed_markers == 1
+ nunique_markers = summed_markers[is_present].index.nunique()
+ num_single_copy_markers = summed_markers[is_single_copy].index.nunique()
+ completeness = nunique_markers / expected_number * 100
+ # Protect from divide by zero
+ if nunique_markers == 0:
+ purity = pd.NA
+ else:
+ purity = num_single_copy_markers / nunique_markers * 100
+ metrics["completeness"].update({cluster: completeness})
+ metrics["purity"].update({cluster: purity})
+ metrics_df = pd.DataFrame(metrics, index=clusters.keys())
+ merged_df = pd.merge(df, metrics_df, left_on="cluster", right_index=True)
+ return merged_df, metrics_df
+
+
+def run_dbscan(df, eps, dropcols=["cluster", "purity", "completeness"]):
+ """Run clustering on `df` at provided `eps`.
+
+ Notes
+ -----
+
+ * documentation for sklearn `DBSCAN `_
+ * documentation for `HDBSCAN `_
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ Contigs with embedded k-mer frequencies as ['x','y'] columns and optionally 'coverage' column
+ eps : float
+ The maximum distance between two samples for one to be considered
+ as in the neighborhood of the other. This is not a maximum bound
+ on the distances of points within a cluster. This is the most
+ important DBSCAN parameter to choose appropriately for your data set
+ and distance function. See `DBSCAN docs `_ for more details.
+ dropcols : list, optional
+ Drop columns in list from `df`
+ (the default is ['cluster','purity','completeness']).
+
+ Returns
+ -------
+ pd.DataFrame
+ `df` with 'cluster' column added
+
+ Raises
+ -------
+    BinningError
+        `df` is missing kmer/coverage annotations
+
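+    Examples
+    --------
+    A usage sketch on a toy embedding (coordinate values are illustrative
+    only):
+
+    >>> import pandas as pd
+    >>> df = pd.DataFrame(
+    ...     {"x": [0.10, 0.12, 5.0], "y": [0.30, 0.31, 5.2]},
+    ...     index=["contig_1", "contig_2", "contig_3"],
+    ... )
+    >>> clustered_df = run_dbscan(df, eps=0.3)
+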
+ """
+ for col in dropcols:
+ if col in df.columns:
+ df.drop(columns=col, inplace=True)
+ n_samples = df.shape[0]
+ if n_samples == 1:
+ clusters = pd.Series([pd.NA], index=df.index, name="cluster")
+ return pd.merge(df, clusters, how="left", left_index=True, right_index=True)
+ cols = ["x", "y"]
+ if "coverage" in df.columns:
+ cols.append("coverage")
+ if np.any(df.isnull()):
+ raise BinningError(
+ f"df is missing {df.isnull().sum().sum()} kmer/coverage annotations"
+ )
+ X = df.loc[:, cols].to_numpy()
+ clusterer = DBSCAN(eps=eps, min_samples=1, n_jobs=-1).fit(X)
+ clusters = pd.Series(clusterer.labels_, index=df.index, name="cluster")
+ return pd.merge(df, clusters, how="left", left_index=True, right_index=True)
+
+
def recursive_dbscan(
- table,
- markers_df,
- domain,
- completeness_cutoff,
- purity_cutoff,
- verbose=False,
- method="DBSCAN",
+ table, markers_df, domain, completeness_cutoff, purity_cutoff, verbose=False,
):
"""Carry out DBSCAN, starting at eps=0.3 and continuing until there is just one
group.
@@ -83,9 +186,6 @@ def recursive_dbscan(
`purity_cutoff` threshold to retain cluster (the default is 90.0).
verbose : bool
log stats for each recursive_dbscan clustering iteration.
- method : str
- clustering `method` to perform. (the default is 'DBSCAN').
- Choices=['DBSCAN','HDBSCAN'] NOTE: 'HDBSCAN' is still under development.
Returns
-------
@@ -105,11 +205,11 @@ def recursive_dbscan(
best_median = float("-inf")
best_df = pd.DataFrame()
while n_clusters > 1:
- binned_df = run_dbscan(table, eps, method=method)
- df = add_metrics(df=binned_df, markers_df=markers_df, domain=domain)
- completess_filter = df["completeness"] >= completeness_cutoff
- purity_filter = df["purity"] >= purity_cutoff
- median_completeness = df[completess_filter & purity_filter][
+ binned_df = run_dbscan(table, eps)
+ df, metrics_df = add_metrics(df=binned_df, markers_df=markers_df, domain=domain)
+ completeness_filter = metrics_df["completeness"] >= completeness_cutoff
+ purity_filter = metrics_df["purity"] >= purity_cutoff
+ median_completeness = metrics_df[completeness_filter & purity_filter][
"completeness"
].median()
if median_completeness >= best_median:
@@ -139,9 +239,9 @@ def recursive_dbscan(
logger.debug("No complete or pure clusters found")
return pd.DataFrame(), table
- completess_filter = best_df["completeness"] >= completeness_cutoff
+ completeness_filter = best_df["completeness"] >= completeness_cutoff
purity_filter = best_df["purity"] >= purity_cutoff
- complete_and_pure_df = best_df.loc[completess_filter & purity_filter]
+ complete_and_pure_df = best_df.loc[completeness_filter & purity_filter]
unclustered_df = best_df.loc[~best_df.index.isin(complete_and_pure_df.index)]
if verbose:
logger.debug(f"Best completeness median: {best_median:4.2f}")
@@ -151,30 +251,36 @@ def recursive_dbscan(
return complete_and_pure_df, unclustered_df
-def run_dbscan(
+def run_hdbscan(
df,
- eps,
+ min_cluster_size,
+ min_samples,
+ cache_dir=None,
dropcols=["cluster", "purity", "completeness"],
- usecols=["x", "y"],
- method="DBSCAN",
):
- """Run clustering via `method`.
+ """Run clustering on `df` at provided `min_cluster_size`.
+
+ Notes
+ -----
+
+ * reasoning for parameter: `cluster_selection_method `_
+ * reasoning for parameters: `min_cluster_size and min_samples `_
+ * documentation for `HDBSCAN `_
Parameters
----------
df : pd.DataFrame
- Description of parameter `df`.
- eps : float
- Description of parameter `eps`.
- dropcols : list
- Drop columns in list from `df` (the default is ['cluster','purity','completeness']).
- usecols : list
- Use columns in list for `df`.
- The default is ['x','y','coverage'] if 'coverage' exists in df.columns.
- else ['x','y','z'].
- method : str
- clustering `method` to perform. (the default is 'DBSCAN').
- Choices=['DBSCAN','HDBSCAN'] NOTE: 'HDBSCAN' is still under development
+ Contigs with embedded k-mer frequencies as ['x','y'] columns and optionally 'coverage' column
+ min_cluster_size : int
+ The minimum size of clusters; single linkage splits that contain
+ fewer points than this will be considered points "falling out" of a
+ cluster rather than a cluster splitting into two new clusters.
+ min_samples : int
+ The number of samples in a neighborhood for a point to be
+ considered a core point.
+ dropcols : list, optional
+ Drop columns in list from `df`
+ (the default is ['cluster','purity','completeness']).
Returns
-------
@@ -185,74 +291,133 @@ def run_dbscan(
-------
-    ValueError
-        sets `usecols` and `dropcols` may not share elements
+    BinningError
+        `df` is missing kmer/coverage annotations
- ValueError
- Method is not an available choice: choices=['DBSCAN','HDBSCAN']
+
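+    Examples
+    --------
+    A usage sketch on a toy embedding (coordinate values are illustrative
+    only):
+
+    >>> import pandas as pd
+    >>> import tempfile
+    >>> df = pd.DataFrame(
+    ...     {"x": [0.10, 0.12, 5.0], "y": [0.30, 0.31, 5.2]},
+    ...     index=["contig_1", "contig_2", "contig_3"],
+    ... )
+    >>> clustered_df = run_hdbscan(
+    ...     df, min_cluster_size=2, min_samples=1, cache_dir=tempfile.mkdtemp()
+    ... )
+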
"""
+ for col in dropcols:
+ if col in df.columns:
+ df.drop(columns=col, inplace=True)
n_samples = df.shape[0]
if n_samples == 1:
clusters = pd.Series([pd.NA], index=df.index, name="cluster")
return pd.merge(df, clusters, how="left", left_index=True, right_index=True)
- for col in dropcols:
- if col in df.columns:
- df.drop(columns=col, inplace=True)
cols = ["x", "y"]
- cols.append("coverage") if "coverage" in df.columns else cols.append("z")
+ if "coverage" in df.columns:
+ cols.append("coverage")
+ if np.any(df.isnull()):
+ raise BinningError(
+ f"df is missing {df.isnull().sum().sum()} kmer/coverage annotations"
+ )
X = df.loc[:, cols].to_numpy()
- if method == "DBSCAN":
- clustering = DBSCAN(eps=eps, min_samples=1, n_jobs=-1).fit(X)
- elif method == "HDBSCAN":
- clustering = HDBSCAN(
- cluster_selection_epsilon=eps,
- min_cluster_size=2,
- min_samples=1,
- allow_single_cluster=True,
- gen_min_span_tree=True,
- ).fit(X)
- else:
- raise ValueError(f"Method: {method} not a choice. choose b/w DBSCAN & HDBSCAN")
- clusters = pd.Series(clustering.labels_, index=df.index, name="cluster")
+ clusterer = HDBSCAN(
+ min_cluster_size=min_cluster_size,
+ min_samples=min_samples,
+ cluster_selection_method="leaf",
+ allow_single_cluster=True,
+ memory=cache_dir,
+ ).fit(X)
+ clusters = pd.Series(clusterer.labels_, index=df.index, name="cluster")
return pd.merge(df, clusters, how="left", left_index=True, right_index=True)
-def add_metrics(df, markers_df, domain="bacteria"):
- """Adds the completeness and purity metrics to each respective contig in df.
+def recursive_hdbscan(
+ table, markers_df, domain, completeness_cutoff, purity_cutoff, verbose=False,
+):
+ """Recursively run HDBSCAN starting with defaults and iterating the min_samples
+ and min_cluster_size until only 1 cluster is recovered.
Parameters
----------
- df : pd.DataFrame
- Description of parameter `df`.
+ table : pd.DataFrame
+ Contigs with embedded k-mer frequencies as ['x','y','z'] columns and
+ optionally 'coverage' column
markers_df : pd.DataFrame
wide format, i.e. index=contig cols=[marker,marker,...]
domain : str
Kingdom to determine metrics (the default is 'bacteria').
choices=['bacteria','archaea']
+ completeness_cutoff : float
+ `completeness_cutoff` threshold to retain cluster (the default is 20.0).
+ purity_cutoff : float
+ `purity_cutoff` threshold to retain cluster (the default is 90.0).
+ verbose : bool
+ log stats for each recursive_dbscan clustering iteration.
Returns
-------
- pd.DataFrame
- DataFrame with added columns - 'completeness' and 'purity'
+ 2-tuple
+ (pd.DataFrame(), pd.DataFrame())
+ DataFrames consisting of contigs that passed/failed clustering
+ cutoffs, respectively.
+
+ DataFrame:
+ index = contig
+            columns = ['x','y','z','coverage','cluster','purity','completeness']
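+
+    Examples
+    --------
+    A hedged sketch; `table` and `markers_df` are prepared upstream as in
+    :func:`binning`:
+
+    >>> clustered, unclustered = recursive_hdbscan(
+    ...     table, markers_df, "bacteria", 20.0, 90.0, verbose=True
+    ... )  # doctest: +SKIP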
"""
- marker_sets = {"bacteria": 139.0, "archaea": 162.0}
- expected_number = marker_sets.get(domain.lower(), 139)
- clusters = dict(list(df.groupby("cluster")))
- metrics = {"purity": {}, "completeness": {}}
- for cluster, dff in clusters.items():
- contigs = dff.index.tolist()
- summed_markers = markers_df[markers_df.index.isin(contigs)].sum()
- is_present = summed_markers >= 1
- is_single_copy = summed_markers == 1
- nunique_markers = summed_markers[is_present].index.nunique()
- num_single_copy_markers = summed_markers[is_single_copy].index.nunique()
- completeness = nunique_markers / expected_number * 100
- # Protect from divide by zero
- if nunique_markers == 0:
- purity = pd.NA
+ max_min_cluster_size = 10000
+ max_min_samples = 10
+ min_cluster_size = 2
+ min_samples = 1
+ n_clusters = float("inf")
+ best_median = float("-inf")
+ best_df = pd.DataFrame()
+ cache_dir = tempfile.mkdtemp()
+ while n_clusters > 1:
+ binned_df = run_hdbscan(
+ table,
+ min_cluster_size=min_cluster_size,
+ min_samples=min_samples,
+ cache_dir=cache_dir,
+ )
+ df, metrics_df = add_metrics(df=binned_df, markers_df=markers_df, domain=domain)
+
+ completeness_filter = metrics_df["completeness"] >= completeness_cutoff
+ purity_filter = metrics_df["purity"] >= purity_cutoff
+ median_completeness = metrics_df[completeness_filter & purity_filter][
+ "completeness"
+ ].median()
+ if median_completeness >= best_median:
+ best_median = median_completeness
+ best_df = df
+
+ n_clusters = df["cluster"].nunique()
+
+ if verbose:
+ logger.debug(
+ f"(min_samples, min_cluster_size): ({min_samples}, {min_cluster_size}) clusters: {n_clusters}"
+ f" median completeness (current, best): ({median_completeness:4.2f}, {best_median:4.2f})"
+ )
+
+ if min_cluster_size >= max_min_cluster_size:
+ shutil.rmtree(cache_dir)
+ cache_dir = tempfile.mkdtemp()
+ min_samples += 1
+ min_cluster_size = 2
else:
- purity = num_single_copy_markers / nunique_markers * 100
- metrics["completeness"].update({cluster: completeness})
- metrics["purity"].update({cluster: purity})
- metrics_df = pd.DataFrame(metrics, index=clusters.keys())
- return pd.merge(df, metrics_df, left_on="cluster", right_index=True)
+ min_cluster_size += 10
+
+ if metrics_df[completeness_filter & purity_filter].empty:
+ min_cluster_size += 100
+
+ if min_samples >= max_min_samples:
+ max_min_cluster_size *= 2
+
+ shutil.rmtree(cache_dir)
+
+ if best_df.empty:
+ if verbose:
+ logger.debug("No complete or pure clusters found")
+ return pd.DataFrame(), table
+
+ completeness_filter = best_df["completeness"] >= completeness_cutoff
+ purity_filter = best_df["purity"] >= purity_cutoff
+ complete_and_pure_df = best_df.loc[completeness_filter & purity_filter]
+ unclustered_df = best_df.loc[~best_df.index.isin(complete_and_pure_df.index)]
+ if verbose:
+ logger.debug(f"Best completeness median: {best_median:4.2f}")
+ logger.debug(
+ f"clustered: {len(complete_and_pure_df)} unclustered: {len(unclustered_df)}"
+ )
+ return complete_and_pure_df, unclustered_df
def get_clusters(
@@ -261,7 +426,7 @@ def get_clusters(
domain="bacteria",
completeness=20.0,
purity=90.0,
- method="DBSCAN",
+ method="dbscan",
verbose=False,
):
"""Find best clusters retained after applying `completeness` and `purity` filters.
@@ -281,8 +446,8 @@ def get_clusters(
purity : float
`purity` threshold to retain cluster (the default is 90.).
method : str
- Description of parameter `method` (the default is 'DBSCAN').
- choices = ['DBSCAN','HDBSCAN']
+ Description of parameter `method` (the default is 'dbscan').
+ choices = ['dbscan','hdbscan']
verbose : bool
log stats for each recursive_dbscan clustering iteration
@@ -293,15 +458,16 @@ def get_clusters(
"""
num_clusters = 0
clusters = []
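+    # dispatch table: the user-facing method name selects the recursive
+    # clustering routine defined above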
+ recursive_clusterers = {"dbscan": recursive_dbscan, "hdbscan": recursive_hdbscan}
+ if method not in recursive_clusterers:
+ raise ValueError(f"Method: {method} not a choice. choose b/w dbscan & hdbscan")
+ clusterer = recursive_clusterers[method]
+
+    # Continue while unclustered contigs remain
+ # break when either clustered_df or unclustered_df is empty
while True:
- clustered_df, unclustered_df = recursive_dbscan(
- master_df,
- markers_df,
- domain,
- completeness,
- purity,
- method=method,
- verbose=verbose,
+ clustered_df, unclustered_df = clusterer(
+ master_df, markers_df, domain, completeness, purity, verbose=verbose,
)
# No contigs can be clustered, label as unclustered and add the final df
# of (unclustered) contigs
@@ -314,7 +480,10 @@ def get_clusters(
c: f"bin_{1+i+num_clusters:04d}"
for i, c in enumerate(clustered_df.cluster.unique())
}
- rename_cluster = lambda c: translation[c]
+
+ def rename_cluster(c):
+ return translation[c]
+
clustered_df.cluster = clustered_df.cluster.map(rename_cluster)
# All contigs have now been clustered, add the final df of (clustered) contigs
@@ -336,8 +505,9 @@ def binning(
completeness=20.0,
purity=90.0,
taxonomy=True,
- method="DBSCAN",
- reverse=True,
+ starting_rank="superkingdom",
+ method="dbscan",
+ reverse_ranks=False,
verbose=False,
):
"""Perform clustering of contigs by provided `method` and use metrics to
@@ -353,23 +523,26 @@ def binning(
i.e. [taxid,superkingdom,phylum,class,order,family,genus,species]
markers : pd.DataFrame
wide format, i.e. index=contig cols=[marker,marker,...]
- domain : str
+ domain : str, optional
Kingdom to determine metrics (the default is 'bacteria').
choices=['bacteria','archaea']
- completeness : float
+ completeness : float, optional
Description of parameter `completeness` (the default is 20.).
- purity : float
+ purity : float, optional
Description of parameter `purity` (the default is 90.).
- taxonomy : bool
+ taxonomy : bool, optional
Split canonical ranks and subset based on rank then attempt to find clusters (the default is True).
taxonomic_levels = [superkingdom,phylum,class,order,family,genus,species]
- method : str
- Clustering `method` (the default is 'DBSCAN').
- choices = ['DBSCAN','HDBSCAN']
- reverse : bool
- True - [superkingdom,phylum,class,order,family,genus,species]
- False - [species,genus,family,order,class,phylum,superkingdom]
- verbose : bool
+ starting_rank : str, optional
+ Starting canonical rank at which to begin subsetting taxonomy (the default is superkingdom).
+ Choices are superkingdom, phylum, class, order, family, genus, species.
+ method : str, optional
+ Clustering `method` (the default is 'dbscan').
+ choices = ['dbscan','hdbscan']
+ reverse_ranks : bool, optional
+ False - [superkingdom,phylum,class,order,family,genus,species] (Default)
+ True - [species,genus,family,order,class,phylum,superkingdom]
+ verbose : bool, optional
log stats for each recursive_dbscan clustering iteration
Returns
@@ -379,44 +552,53 @@ def binning(
Raises
-------
- RecursiveDBSCANError
+ BinningError
        No marker information is available for contigs to be binned.
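+
+    Examples
+    --------
+    A hedged sketch; `master_df` (k-mer embedding, coverage and taxonomy
+    columns) and `markers_df` are prepared upstream in the pipeline:
+
+    >>> binned_df = binning(
+    ...     master=master_df,
+    ...     markers=markers_df,
+    ...     domain="bacteria",
+    ...     taxonomy=True,
+    ...     starting_rank="superkingdom",
+    ...     method="dbscan",
+    ... )  # doctest: +SKIP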
"""
# First check needs to ensure we have markers available to check binning quality...
if master.loc[master.index.isin(markers.index)].empty:
- err = "No markers for contigs in table. Unable to assess binning quality"
- raise RecursiveDBSCANError(err)
+ raise BinningError(
+ "No markers for contigs in table. Unable to assess binning quality"
+ )
+ logger.info(f"Using {method} clustering method")
if not taxonomy:
return get_clusters(
- master, markers, domain, completeness, purity, method, verbose
+ master_df=master,
+ markers_df=markers,
+ domain=domain,
+ completeness=completeness,
+ purity=purity,
+ method=method,
+ verbose=verbose,
)
# Use taxonomy method
- if reverse:
- # superkingdom, phylum, class, order, family, genus, species
- ranks = [rank for rank in reversed(NCBI.CANONICAL_RANKS)]
- else:
+ if reverse_ranks:
# species, genus, family, order, class, phylum, superkingdom
ranks = [rank for rank in NCBI.CANONICAL_RANKS]
+ else:
+ # superkingdom, phylum, class, order, family, genus, species
+ ranks = [rank for rank in reversed(NCBI.CANONICAL_RANKS)]
ranks.remove("root")
+ starting_rank_index = ranks.index(starting_rank)
+ ranks = ranks[starting_rank_index:]
+ logger.debug(f"Using ranks: {', '.join(ranks)}")
clustered_contigs = set()
num_clusters = 0
clusters = []
for rank in ranks:
# TODO: We should account for novel taxa here instead of removing 'unclassified'
unclassified_filter = master[rank] != "unclassified"
- n_contigs_in_taxa = (
- master.loc[unclassified_filter].groupby(rank)[rank].count().sum()
- )
- n_taxa = (
- master.loc[unclassified_filter].groupby(rank)[rank].count().index.nunique()
- )
+ master_grouped_by_rank = master.loc[unclassified_filter].groupby(rank)
+ taxa_counts = master_grouped_by_rank[rank].count()
+ n_contigs_in_taxa = taxa_counts.sum()
+ n_taxa = taxa_counts.index.nunique()
logger.info(
f"Examining {rank}: {n_taxa:,} unique taxa ({n_contigs_in_taxa:,} contigs)"
)
# Group contigs by rank and find best clusters within subset
- for rank_name_txt, dff in master.loc[unclassified_filter].groupby(rank):
+ for rank_name_txt, dff in master_grouped_by_rank:
if dff.empty:
continue
# Only cluster contigs that have not already been assigned a bin.
@@ -428,14 +610,20 @@ def binning(
if clustered_contigs:
rank_df = rank_df.loc[~rank_df.index.isin(clustered_contigs)]
# After all of the filters, are there multiple contigs to cluster?
- if len(rank_df) <= 1:
+ if rank_df.empty:
continue
# Find best clusters
logger.debug(
f"Examining taxonomy: {rank} : {rank_name_txt} : {rank_df.shape}"
)
clusters_df = get_clusters(
- rank_df, markers, domain, completeness, purity, method
+ master_df=rank_df,
+ markers_df=markers,
+ domain=domain,
+ completeness=completeness,
+ purity=purity,
+ method=method,
+ verbose=verbose,
)
# Store clustered contigs
is_clustered = clusters_df["cluster"].notnull()
@@ -447,7 +635,10 @@ def binning(
c: f"bin_{1+i+num_clusters:04d}"
for i, c in enumerate(clustered.cluster.unique())
}
- rename_cluster = lambda c: translation[c]
+
+ def rename_cluster(c):
+ return translation[c]
+
clustered.cluster = clustered.cluster.map(rename_cluster)
num_clusters += clustered.cluster.nunique()
clusters.append(clustered)
@@ -467,7 +658,10 @@ def main():
level=logger.DEBUG,
)
parser = argparse.ArgumentParser(
- description="Perform decomposition/embedding/clustering via PCA/[TSNE|UMAP]/DBSCAN."
+ description="Perform marker gene guided binning of "
+ "metagenome contigs using annotations (when available) of sequence "
+ "composition, coverage and homology.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("kmers", help="")
parser.add_argument("coverage", help="")
@@ -477,20 +671,41 @@ def main():
parser.add_argument(
"--embedding-method",
help="Embedding method to use",
- choices=["TSNE", "UMAP"],
- default="TSNE",
+ choices=["bhsne", "sksne", "umap"],
+ default="bhsne",
)
parser.add_argument(
"--clustering-method",
help="Clustering method to use",
- choices=["DBSCAN", "HDBSCAN"],
- default="DBSCAN",
+ choices=["dbscan", "hdbscan"],
+ default="dbscan",
)
parser.add_argument(
"--completeness", help="", default=20.0, type=float
)
parser.add_argument("--purity", help="", default=90.0, type=float)
parser.add_argument("--taxonomy", help="")
+ parser.add_argument(
+ "--starting-rank",
+ help="Canonical rank at which to begin subsetting taxonomy",
+ default="superkingdom",
+ choices=[
+ "superkingdom",
+ "phylum",
+ "class",
+ "order",
+ "family",
+ "genus",
+ "species",
+ ],
+ )
+ parser.add_argument(
+ "--reverse-ranks",
+ action="store_true",
+ default=False,
+ help="Reverse order at which to split taxonomy by canonical-rank."
+        " When --reverse-ranks is given, contigs will be split in order of species, genus, family, order, class, phylum, superkingdom.",
+ )
parser.add_argument(
"--domain",
help="Kingdom to consider (archaea|bacteria)",
@@ -508,7 +723,7 @@ def main():
cov_df = pd.read_csv(args.coverage, sep="\t", index_col="contig")
master_df = pd.merge(
- kmers_df, cov_df[["coverage"]], how="left", left_index=True, right_index=True
+ kmers_df, cov_df[["coverage"]], how="left", left_index=True, right_index=True,
)
markers_df = Markers.load(args.markers)
@@ -532,6 +747,8 @@ def main():
master=master_df,
markers=markers_df,
taxonomy=taxa_present,
+ starting_rank=args.starting_rank,
+ reverse_ranks=args.reverse_ranks,
domain=args.domain,
completeness=args.completeness,
purity=args.purity,
diff --git a/autometa/common/exceptions.py b/autometa/common/exceptions.py
index a6fa86a8f..0e53579a3 100644
--- a/autometa/common/exceptions.py
+++ b/autometa/common/exceptions.py
@@ -67,8 +67,8 @@ def __str__(self):
return self.value
-class RecursiveDBSCANError(AutometaException):
- """RecursiveDBSCANError exception class."""
+class BinningError(AutometaException):
+ """BinningError exception class."""
def __init__(self, value):
self.value = value
diff --git a/autometa/common/external/diamond.py b/autometa/common/external/diamond.py
index 8330cc38d..9de1df1f6 100644
--- a/autometa/common/external/diamond.py
+++ b/autometa/common/external/diamond.py
@@ -29,6 +29,8 @@
import os
import subprocess
+import multiprocessing as mp
+
from itertools import chain
from tqdm import tqdm
@@ -39,38 +41,19 @@
class DiamondResult:
- """DiamondResult class
+ """
+    DiamondResult class instantiated for each qseqid in the diamond output.
+    Allows for better handling of the subject sequence IDs that hit a given qseqid.
- Parameters
- ----------
- qseqid : str
- query sequence ID
- sseqid : str
- subject sequence ID
- pident : float
- Percentage of identical matches.
- length : int
- Alignment length.
- mismatch : int
- Number of mismatches.
- gapopen : int
- Number of gap openings.
- qstart : int
- Start of alignment in query.
- qend : int
- End of alignment in query.
- sstart : int
- Start of alignment in subject.
- send : int
- End of alignment in subject sequence.
- evalue : float
- Expect value.
- bitscore : float
- Bitscore.
+    Includes methods to modify a DiamondResult object (add or remove a sseqid from the sseqids dictionary),
+    check whether two DiamondResult objects share the same qseqid, and return user-friendly and unambiguous
+    output from str() and repr(), respectively. Also provides a method to retrieve the sseqid with the
+    highest bitscore amongst all the subject sequences that hit a query. These methods come in handy when
+    retrieving diamond results from the output table.
Attributes
----------
sseqids : dict
+        All the subject sequences that hit the query sequence
{sseqid:parameters, sseqid:parameters, ...}
qseqid: str
result query sequence ID
@@ -92,6 +75,36 @@ def __init__(
evalue,
bitscore,
):
+ """
+ Instantiates the DiamondResult class
+
+ Parameters
+ ----------
+ qseqid : str
+ query sequence ID
+ sseqid : str
+ subject sequence ID
+ pident : float
+ Percentage of identical matches.
+ length : int
+ Alignment length.
+ mismatch : int
+ Number of mismatches.
+ gapopen : int
+ Number of gap openings.
+ qstart : int
+ Start of alignment in query.
+ qend : int
+ End of alignment in query.
+ sstart : int
+ Start of alignment in subject.
+ send : int
+ End of alignment in subject sequence.
+ evalue : float
+ Expect value.
+ bitscore : float
+ Bitscore.
+ """
self.qseqid = qseqid
self.sseqids = {
sseqid: {
@@ -108,29 +121,110 @@ def __init__(
}
}
- # def __repr__(self):
- # return str(self)
+ def __repr__(self):
+ """
+ Operator overloading to return the representation of the class object
+
+ Returns
+ -------
+ str
+ Class name followed by query sequence ID
+ """
+ return f"Class: {self.__class__.__name__}, Query seqID: {self.qseqid}"
def __str__(self):
+ """
+ Operator overloading to return the string representation of the class objects
+
+ Returns
+ -------
+ str
+            String representation of the query sequence ID, followed by the total number of hits
+            and finally the top hit by bitscore
+ """
return f"{self.qseqid}; {len(self.sseqids)} sseqids; top hit by bitscore: {self.get_top_hit()}"
def __eq__(self, other_hit):
+ """
+ Operator overloading to compare two objects of the DiamondResult class
+
+ Parameters
+ ----------
+ other_hit : DiamondResult object
+ Other DiamondResult object to compare with
+
+ Returns
+ -------
+        bool
+            True if the qseqids of both DiamondResult objects are equal, else False
+ """
if self.qseqid == other_hit.qseqid:
return True
else:
return False
def __add__(self, other_hit):
- assert self == other_hit, f"qseqids do not match! {self} & {other_hit}"
+ """
+ Operator overloading to update (add) the sseqids dictionary with the other sseqid hit dictionary.
+ The addition will only be successful if both the DiamondResult objects have the same qseqid
+
+ Parameters
+ ----------
+ other_hit : DiamondResult object
+ Other DiamondResult object to add
+
+ Returns
+ -------
+ DiamondResult object
+ DiamondResult object whose sseqids dict has been updated (added) with another sseqid
+
+ Raises
+ ------
+ AssertionError
+ Query sequences are not equal
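+
+        Examples
+        --------
+        A sketch; `hit` and `other_hit` are DiamondResult objects sharing the
+        same qseqid:
+
+        >>> hit = hit + other_hit  # doctest: +SKIP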
+ """
+ assert (
+ self == other_hit
+ ), f"qseqids do not match! {self.qseqid} & {other_hit.qseqid}"
self.sseqids.update(other_hit.sseqids)
return self
def __sub__(self, other_hit):
- assert self == other_hit, f"qseqids do not match! {self} & {other_hit}"
- self.sseqids.pop(other_hit.sseqid)
+ """
+ Operator overloading to remove (subtract) the other sseqid hits dictionary from the sseqids dictionary.
+ The subtraction will only be successful if both the DiamondResult objects have the same qseqid
+
+ Parameters
+ ----------
+        other_hit : DiamondResult object
+ Other DiamondResult object to subtract
+
+ Returns
+ -------
+ DiamondResult object
+ DiamondResult object where a sseqid has been removed (subtracted) from the sseqids dict
+
+ Raises
+ ------
+ AssertionError
+ Query sequences are not equal
+ """
+ assert (
+ self == other_hit
+ ), f"qseqids do not match! {self.qseqid} & {other_hit.qseqid}"
+ for sseqid in other_hit.sseqids:
+ try:
+ self.sseqids.pop(sseqid)
+ except KeyError:
+ raise KeyError(
+ f"Given sseqid: {sseqid} is absent from the corresponding DiamondResult"
+ )
return self
def get_top_hit(self):
+ """
+ Returns the subject sequence ID with the highest bitscore amongst all the subject sequences that hit a query
+ """
top_bitscore = float("-Inf")
top_hit = None
for sseqid, attrs in self.sseqids.items():
@@ -140,13 +234,36 @@ def get_top_hit(self):
return top_hit
-def makedatabase(fasta, database, nproc=1):
- cmd = f"diamond makedb --in {fasta} --db {database} -p {nproc}"
- logger.debug(f"{cmd}")
- with open(os.devnull, "w") as stdout, open(os.devnull, "w") as stderr:
- retcode = subprocess.call(cmd, stdout=stdout, stderr=stderr, shell=True)
- if retcode:
- raise OSError(f"DiamondFailed:\nArgs:{proc.args}\nReturnCode:{proc.returncode}")
+def makedatabase(fasta, database, cpus=mp.cpu_count()):
+ """
+    Creates a diamond formatted database against which query sequences will be searched
+
+ Parameters
+ ----------
+ fasta : str
+        Path to the fasta file from which the diamond database will be built
+ e.g. ''
+ database : str
+ Path to the output diamond formatted database file
+ e.g. ''
+ cpus : int, optional
+ Number of processors to be used. By default uses all the processors of the system
+
+ Returns
+ -------
+ str
+ Path to diamond formatted database
+
+ Raises
+ ------
+ subprocess.CalledProcessError
+ Failed to create diamond formatted database
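+
+    Examples
+    --------
+    A hedged sketch (file paths are placeholders):
+
+    >>> dmnd_db = makedatabase("nr.gz", "nr.dmnd", cpus=8)  # doctest: +SKIP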
+ """
+ cmd = ["diamond", "makedb", "--in", fasta, "--db", database, "-p", str(cpus)]
+ logger.debug(" ".join(cmd))
+ subprocess.run(
+ cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
+ )
return database
@@ -157,55 +274,57 @@ def blast(
blast_type="blastp",
evalue=float("1e-5"),
maxtargetseqs=200,
- cpus=0,
- tmpdir=os.curdir,
+ cpus=mp.cpu_count(),
+ tmpdir=None,
force=False,
verbose=False,
):
- """Performs diamond blastp search using fasta against diamond formatted database
+ """
+    Performs diamond blastp (or blastx) search of query sequences against a diamond formatted database
Parameters
----------
fasta : str
- . May be amino acid or nucleotide sequences
+        Path to fasta file containing the query sequences. These should be amino acid sequences
+        for blastp and nucleotide sequences for blastx
database : str
- .
+ Path to diamond formatted database
outfpath : str
- .
- blast_type : str
- blastp if fasta consists of amino acids or blastx if nucleotides
- evalue : float
- cutoff e-value to count hit as significant (the default is float('1e-5')).
- maxtargetseqs : int
- max number of target sequences to retrieve per query by diamond (the default is 200).
- cpus : int
- number of cpus to use (the default is 1).
- tmpdir : type
- (the default is os.curdir).
- force : boolean
- overwrite existing diamond results `force` (the default is False).
- verbose : boolean
- log progress to terminal `verbose` (the default is False).
+ Path to output file
+ blast_type : str, optional
+ blastp to align protein query sequences against a protein reference database,
+ blastx to align translated DNA query sequences against a protein reference database, by default 'blastp'
+ evalue : float, optional
+ cutoff e-value to count hit as significant, by default float('1e-5')
+ maxtargetseqs : int, optional
+ max number of target sequences to retrieve per query by diamond, by default 200
+ cpus : int, optional
+ Number of processors to be used, by default uses all the processors of the system
+ tmpdir : str, optional
+ Path to temporary directory. By default, same as the output directory
+ force : bool, optional
+ overwrite existing diamond results, by default False
+ verbose : bool, optional
+ log progress to terminal, by default False
Returns
-------
str
- `outfpath`
+ Path to BLAST results
Raises
- -------
+ ------
FileNotFoundError
`fasta` file does not exist
ValueError
provided `blast_type` is not 'blastp' or 'blastx'
- OSError
- Diamond execution failed
+ subprocess.CalledProcessError
+ Failed to run blast
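+
+    Examples
+    --------
+    A hedged sketch (file paths are placeholders):
+
+    >>> outfpath = blast(
+    ...     fasta="orfs.faa",
+    ...     database="nr.dmnd",
+    ...     outfpath="blastp.tsv",
+    ...     blast_type="blastp",
+    ... )  # doctest: +SKIP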
"""
if not os.path.exists(fasta):
raise FileNotFoundError(fasta)
if os.path.exists(outfpath) and not force:
- empty = not os.stat(outfpath).st_size
- if not empty:
+ if os.path.getsize(outfpath):
if verbose:
logger.warning(f"FileExistsError: {outfpath}. To overwrite use --force")
return outfpath
@@ -213,10 +332,10 @@ def blast(
if blast_type not in ["blastp", "blastx"]:
raise ValueError(f"blast_type must be blastp or blastx. Given: {blast_type}")
if verbose:
- logger.debug(f"Diamond{blast_type.title()} {fasta} against {database}")
+ logger.debug(f"diamond {blast_type} {fasta} against {database}")
cmd = [
"diamond",
- "blastp",
+ blast_type,
"--query",
fasta,
"--db",
@@ -231,28 +350,33 @@ def blast(
"6",
"--out",
outfpath,
- "--tmpdir",
- tmpdir,
]
+ if tmpdir:
+ cmd.extend(["--tmpdir", tmpdir])
+    # subprocess requires each element of the cmd list to be a string
cmd = [str(c) for c in cmd]
if verbose:
logger.debug(f'RunningDiamond: {" ".join(cmd)}')
- with open(os.devnull, "w") as stdout, open(os.devnull, "w") as stderr:
- proc = subprocess.run(cmd, stdout=stdout, stderr=stderr)
- if proc.returncode:
- raise OSError(f"DiamondFailed:\nArgs:{proc.args}\nReturnCode:{proc.returncode}")
+ subprocess.run(
+ cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
+ )
return outfpath
-def parse(results, top_pct=0.9, verbose=False):
- """Retrieve diamond results from output table
+def parse(results, bitscore_filter=0.9, verbose=False):
+ """
+ Retrieve diamond results from output table
Parameters
----------
results : str
-
- top_pct : 0 < float <= 1
- bitscore filter applied to each qseqid (the default is 0.9).
+        Path to diamond blast output file (tabular --outfmt 6)
+ bitscore_filter : 0 < float <= 1, optional
+ Bitscore filter applied to each sseqid, by default 0.9
+ Used to determine whether the bitscore is above a threshold value.
+ For example, if it is 0.9 then only bitscores >= 0.9 * the top bitscore are accepted
+ verbose : bool, optional
+ log progress to terminal, by default False
Returns
-------
@@ -264,7 +388,7 @@ def parse(results, top_pct=0.9, verbose=False):
FileNotFoundError
diamond results table does not exist
ValueError
- top_pct value is not a float or not in range of 0 to 1
+ bitscore_filter value is not a float or not in range of 0 to 1
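+
+    Examples
+    --------
+    A hedged sketch ("blastp.tsv" is a placeholder path):
+
+    >>> hits = parse("blastp.tsv", bitscore_filter=0.9)  # doctest: +SKIP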
"""
disable = False if verbose else True
# boolean toggle --> keeping above vs. below because I think this is more readable.
@@ -274,14 +398,14 @@ def parse(results, top_pct=0.9, verbose=False):
if not os.path.exists(results):
raise FileNotFoundError(results)
try:
- float(top_pct)
- except ValueError as err:
+ float(bitscore_filter)
+ except ValueError:
raise ValueError(
- f"top_pct must be a float! Input: {top_pct} Type: {type(top_pct)}"
+ f"bitscore_filter must be a float! Input: {bitscore_filter} Type: {type(bitscore_filter)}"
)
- in_range = 0.0 < top_pct <= 1.0
+ in_range = 0.0 < bitscore_filter <= 1.0
if not in_range:
- raise ValueError(f"top_pct not in range(0,1)! Input: {top_pct}")
+ raise ValueError(f"bitscore_filter not in range(0,1)! Input: {bitscore_filter}")
hits = {}
temp = set()
n_lines = file_length(results) if verbose else None
@@ -321,13 +445,17 @@ def parse(results, top_pct=0.9, verbose=False):
topbitscore = bitscore
temp = set([hit.qseqid])
continue
- if bitscore >= top_pct * topbitscore:
+ if bitscore >= bitscore_filter * topbitscore:
hits[hit.qseqid] += hit
return hits
def add_taxids(hits, database, verbose=True):
- """Translates accessions to taxid translations from prot.accession2taxid.gz
+ """
+    Translates accessions to taxids from prot.accession2taxid.gz. If an accession number is no
+    longer available in prot.accession2taxid.gz (either due to being suppressed, deprecated or
+    removed by NCBI), then None is returned as the taxid for the corresponding sseqid.
+
# TODO: Should maybe write a wrapper for this to run on all of the NCBI
databases listed below... Maybe this will help account for instances where the
@@ -347,7 +475,9 @@ def add_taxids(hits, database, verbose=True):
hits : dict
{qseqid: DiamondResult, ...}
database : str
-
+ Path to prot.accession2taxid.gz database
+ verbose : bool, optional
+        log progress to terminal, by default True
Returns
-------
@@ -358,13 +488,6 @@ def add_taxids(hits, database, verbose=True):
-------
FileNotFoundError
prot.accession2taxid.gz database is required for translation taxid
- DatabasesOutOfDateError
- prot.accession2taxid.gz database and nr.dmnd are out of sync resulting
- in accessions that are no longer available (either due to being
- suppressed, deprecated or removed by NCBI). This must be resolved by
- updating both nr and prot.accession2taxid.gz and re-running diamond on
- the new nr.dmnd database. Alternatively, can try to find the exceptions in merged.dmp
-
# TODO: Replace file_length func for database file.
(in this case 808,717,857 lines takes ~15 minutes simply to read each line...)
@@ -379,18 +502,17 @@ def add_taxids(hits, database, verbose=True):
accessions = set(
chain.from_iterable([hit.sseqids.keys() for qseqid, hit in hits.items()])
)
- fh = gzip.open(database) if database.endswith(".gz") else open(database)
- __ = fh.readline()
+    # "rt" opens the database in text mode instead of binary, so it can be handled like a text file
+ fh = gzip.open(database, "rt") if database.endswith(".gz") else open(database)
+    __ = fh.readline()  # skip the first line, which is just the column header
if verbose:
logger.debug(
f"Searching for {len(accessions):,} accessions in {os.path.basename(database)}. This may take a while..."
)
- is_gzipped = True if database.endswith(".gz") else False
n_lines = file_length(database) if verbose else None
desc = f"Parsing {os.path.basename(database)}"
acc2taxids = {}
for line in tqdm(fh, disable=disable, desc=desc, total=n_lines, leave=False):
- line = line.decode() if is_gzipped else line
acc_num, acc_ver, taxid, _ = line.split("\t")
taxid = int(taxid)
if acc_num in accessions:
@@ -405,48 +527,23 @@ def add_taxids(hits, database, verbose=True):
):
for sseqid in hit.sseqids:
taxid = acc2taxids.get(sseqid)
- # if taxid is None:
- # raise DatabasesOutOfDateError(f'{sseqid} deprecated/suppressed/removed')
hit.sseqids[sseqid].update({"taxid": taxid})
return hits
-def main(args):
- result = blast(
- fasta=args.fasta,
- database=args.database,
- outfpath=args.outfile,
- blast_type=args.blast_type,
- evalue=args.evalue,
- maxtargetseqs=args.maxtargetseqs,
- cpus=args.cpus,
- tmpdir=args.tmpdir,
- force=args.force,
- verbose=args.verbose,
- )
- hits = parse(results=result, top_pct=args.top_pct, verbose=args.verbose)
- hits = add_taxids(hits=hits, database=args.acc2taxids, verbose=args.verbose)
- fname, __ = os.path.splitext(os.path.basename(args.outfile))
- dirpath = os.path.dirname(os.path.realpath(args.outfile))
- hits_fname = ".".join([fname, "pkl.gz"])
- hits_fpath = os.path.join(dirpath, hits_fname)
- pickled_fpath = make_pickle(obj=hits, outfpath=hits_fpath)
- logger.debug(f"{len(hits):,} diamond hits serialized to {pickled_fpath}")
-
-
-if __name__ == "__main__":
+def main():
import argparse
- import os
parser = argparse.ArgumentParser(
description="""
Retrieves blastp hits with provided input assembly
- """
+ """,
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
- parser.add_argument("fasta", help="")
- parser.add_argument("database", help="")
- parser.add_argument("acc2taxids", help="")
- parser.add_argument("outfile", help="")
+ parser.add_argument("fasta", help="Path to fasta file having the query sequences")
+ parser.add_argument("database", help="Path to diamond formatted database")
+ parser.add_argument("acc2taxids", help="Path to prot.accession2taxid.gz database")
+ parser.add_argument("outfile", help="Path to output file")
parser.add_argument(
"blast_type",
help="[blastp]: A.A -> A.A. [blastx]: Nucl. -> A.A.",
@@ -462,14 +559,46 @@ def main(args):
default=200,
type=int,
)
- parser.add_argument("--cpus", help="num cpus to use", default=0, type=int)
- parser.add_argument("--tmpdir", help="", default=os.curdir)
parser.add_argument(
- "--top-pct", help="top percentage of hits to retrieve", default=0.9
+ "--cpus", help="number of processors to use", default=mp.cpu_count(), type=int
+ )
+ parser.add_argument(
+ "--tmpdir",
+ help="Path to directory which will be used for temporary storage by diamond",
+ )
+ parser.add_argument(
+ "--bitscore-filter",
+        help="hits with a bitscore >= bitscore-filter * top bitscore will be kept",
+        default=0.9,
+        type=float,
)
parser.add_argument(
"--force", help="force overwrite of diamond output table", action="store_true"
)
parser.add_argument("--verbose", help="add verbosity", action="store_true")
args = parser.parse_args()
- main(args)
+ result = blast(
+ fasta=args.fasta,
+ database=args.database,
+ outfpath=args.outfile,
+ blast_type=args.blast_type,
+ evalue=args.evalue,
+ maxtargetseqs=args.maxtargetseqs,
+ cpus=args.cpus,
+ tmpdir=args.tmpdir,
+ force=args.force,
+ verbose=args.verbose,
+ )
+ hits = parse(
+ results=result, bitscore_filter=args.bitscore_filter, verbose=args.verbose
+ )
+ hits = add_taxids(hits=hits, database=args.acc2taxids, verbose=args.verbose)
+ fname, __ = os.path.splitext(os.path.basename(args.outfile))
+ dirpath = os.path.dirname(os.path.realpath(args.outfile))
+ hits_fname = ".".join([fname, "pkl.gz"])
+ hits_fpath = os.path.join(dirpath, hits_fname)
+ pickled_fpath = make_pickle(obj=hits, outfpath=hits_fpath)
+ logger.debug(f"{len(hits):,} diamond hits serialized to {pickled_fpath}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/autometa/common/utilities.py b/autometa/common/utilities.py
index e7951b19f..dd0d7f72d 100644
--- a/autometa/common/utilities.py
+++ b/autometa/common/utilities.py
@@ -29,9 +29,12 @@
import logging
import os
import pickle
+import sys
import tarfile
import time
+import numpy as np
+
from functools import wraps
@@ -101,8 +104,8 @@ def make_pickle(obj, outfpath):
return outfpath
-def gunzip(infpath, outfpath):
- """Decompress gzipped `infpath` to `outfpath`.
+def gunzip(infpath, outfpath, delete_original=False, block_size=65536):
+ """Decompress gzipped `infpath` to `outfpath` and write checksum of `outfpath` upon successful decompression.
Parameters
----------
@@ -110,6 +113,10 @@ def gunzip(infpath, outfpath):
outfpath : str
+ delete_original : bool
+ Will delete the original file after successfully decompressing `infpath` (Default is False).
+ block_size : int
+        Amount of `infpath` to read into memory before writing to `outfpath` (Default is 65536 bytes).
Returns
-------
@@ -125,15 +132,21 @@ def gunzip(infpath, outfpath):
logger.debug(
f"gunzipping {os.path.basename(infpath)} to {os.path.basename(outfpath)}"
)
- if os.path.exists(outfpath) and os.stat(outfpath).st_size > 0:
+ if os.path.exists(outfpath) and os.path.getsize(outfpath) > 0:
raise FileExistsError(outfpath)
lines = ""
- with gzip.open(infpath) as fh:
- for line in fh:
- lines += line.decode()
- with open(outfpath, "w") as out:
+ with gzip.open(infpath, "rt") as fh, open(outfpath, "w") as out:
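+        # buffer decompressed text and flush to `outfpath` whenever the
+        # buffer size reaches `block_size`, so the whole file is never held
+        # in memory at once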
+ for i, line in enumerate(fh):
+ lines += line
+ if sys.getsizeof(lines) >= block_size:
+ out.write(lines)
+ lines = ""
out.write(lines)
logger.debug(f"gunzipped {infpath} to {outfpath}")
+ write_checksum(outfpath, f"{outfpath}.md5")
+ if delete_original:
+ os.remove(infpath)
+ logger.debug(f"removed original file: {infpath}")
return outfpath
@@ -149,7 +162,7 @@ def untar(tarchive, outdir, member=None):
outdir : str
- member : str
+ member : str, optional
member file to extract.
Returns
@@ -165,6 +178,7 @@ def untar(tarchive, outdir, member=None):
`tarchive` is not a tar archive
KeyError
`member` was not found in `tarchive`
+
"""
if not member and not outdir:
raise ValueError(
@@ -172,7 +186,7 @@ def untar(tarchive, outdir, member=None):
)
logger.debug(f"decompressing tarchive {tarchive} to {outdir}")
outfpath = os.path.join(outdir, member) if member else None
- if member and os.path.exists(outfpath) and os.stat(outfpath).st_size > 0:
+ if member and os.path.exists(outfpath) and os.path.getsize(outfpath) > 0:
raise FileExistsError(outfpath)
if not tarfile.is_tarfile(tarchive):
raise ValueError(f"{tarchive} is not a tar archive")
@@ -219,6 +233,7 @@ def tarchive_results(outfpath, src_dirpath):
-------
FileExistsError
`outfpath` already exists
+
"""
logger.debug(f"tar archiving {src_dirpath} to {outfpath}")
if os.path.exists(outfpath):
@@ -229,16 +244,17 @@ def tarchive_results(outfpath, src_dirpath):
return outfpath
-def file_length(fpath):
+def file_length(fpath, approximate=False):
"""Retrieve the number of lines in `fpath`
- See:
- https://stackoverflow.com/questions/845058/how-to-get-line-count-of-a-large-file-cheaply-in-python
+ See: https://stackoverflow.com/q/845058/13118765
Parameters
----------
fpath : str
Description of parameter `fpath`.
+    approximate : bool, optional
+        If True, approximate the number of lines from the file size (the default is False).
Returns
-------
@@ -253,18 +269,28 @@ def file_length(fpath):
"""
if not os.path.exists(fpath):
raise FileNotFoundError(fpath)
- if fpath.endswith(".gz"):
- fh = gzip.open(fpath, "rb")
- else:
- fh = open(fpath, "rb")
+
+ fh = gzip.open(fpath, "rt") if fpath.endswith(".gz") else open(fpath, "rb")
+ if approximate:
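+        # approximate the line count by sampling the in-memory size of up to
+        # the first 100,000 lines and dividing the total file size by the
+        # average line size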
+ lines = []
+ n_sample_lines = 100000
+ for i, l in enumerate(fh):
+ if i > n_sample_lines:
+ break
+ lines.append(sys.getsizeof(l))
+ fh.close()
+ avg_size_per_line = np.average(lines)
+ total_size = os.path.getsize(fpath)
+ return int(np.ceil(total_size / avg_size_per_line))
+
for i, l in enumerate(fh):
pass
fh.close()
return i + 1
-def get_checksum(fpath):
- """Retrieve sha256 checksums from provided `args`.
+def calc_checksum(fpath):
+ """Retrieve md5 checksum from provided `fpath`.
See:
https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file
@@ -277,7 +303,8 @@ def get_checksum(fpath):
Returns
-------
str
- hexdigest of `fpath` using sha256
+ space-delimited hexdigest of `fpath` using md5sum and basename of `fpath`.
+ e.g. 'hash filename\n'
Raises
-------
@@ -285,10 +312,11 @@ def get_checksum(fpath):
Provided `fpath` does not exist
TypeError
`fpath` is not a string
+
"""
- def sha(block):
- hasher = hashlib.sha256()
+ def md5sum(block):
+ hasher = hashlib.md5()
for bytes in block:
hasher.update(bytes)
return hasher.hexdigest()
@@ -300,14 +328,78 @@ def blockiter(fh, blocksize=65536):
yield block
block = fh.read(blocksize)
- if type(fpath) != str:
+ if not isinstance(fpath, str):
raise TypeError(type(fpath))
if not os.path.exists(fpath):
raise FileNotFoundError(fpath)
fh = open(fpath, "rb")
- cksum = sha(blockiter(fh))
+ hash = md5sum(blockiter(fh))
fh.close()
- return cksum
+ return f"{hash} {os.path.basename(fpath)}\n"
+
+
+def read_checksum(fpath):
+ """Read checksum from provided checksum formatted `fpath`.
+
+ Note: See `write_checksum` for how a checksum file is generated.
+
+ Parameters
+ ----------
+ fpath : str
+
+
+ Returns
+ -------
+ str
+ checksum retrieved from `fpath`.
+
+ Raises
+ -------
+ TypeError
+ Provided `fpath` was not a string.
+ FileNotFoundError
+ Provided `fpath` does not exist.
+
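+    Examples
+    --------
+    A sketch ("nr.dmnd.md5" is a placeholder path):
+
+    >>> checksum_line = read_checksum("nr.dmnd.md5")  # doctest: +SKIP
+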
+ """
+ if not isinstance(fpath, str):
+ raise TypeError(type(fpath))
+ if not os.path.exists(fpath):
+ raise FileNotFoundError(fpath)
+ with open(fpath) as fh:
+ return fh.readline()
+
+
+def write_checksum(infpath, outfpath):
+ """Calculate checksum for `infpath` and write to `outfpath`.
+
+ Parameters
+ ----------
+ infpath : str
+
+ outfpath : str
+
+
+ Returns
+ -------
+    NoneType
+        The checksum of `infpath` is written to `outfpath`.
+
+ Raises
+ -------
+ FileNotFoundError
+ Provided `infpath` does not exist
+ TypeError
+ `infpath` or `outfpath` is not a string
+
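+    Examples
+    --------
+    A sketch ("nr.dmnd" is a placeholder path):
+
+    >>> write_checksum("nr.dmnd", "nr.dmnd.md5")  # doctest: +SKIP
+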
+ """
+ if not os.path.exists(infpath):
+ raise FileNotFoundError(infpath)
+ if not isinstance(outfpath, str):
+ raise TypeError(type(outfpath))
+ checksum = calc_checksum(infpath)
+ with open(outfpath, "w") as fh:
+ fh.write(checksum)
+ logger.debug(f"Wrote {infpath} checksum to {outfpath}")
def valid_checkpoint(checkpoint_fp, fpath):
@@ -331,9 +423,10 @@ def valid_checkpoint(checkpoint_fp, fpath):
Either `fpath` or `checkpoint_fp` does not exist
TypeError
Either `fpath` or `checkpoint_fp` is not a string
+
"""
for fp in [checkpoint_fp, fpath]:
- if not type(fp) is str:
+ if not isinstance(fp, str):
raise TypeError(f"{fp} is type: {type(fp)}")
if not os.path.exists(fp):
raise FileNotFoundError(fp)
@@ -345,8 +438,8 @@ def valid_checkpoint(checkpoint_fp, fpath):
# If filepaths never match, prev_chksum and new_chksum will not match.
# Giving expected result.
break
- new_chksum = get_checksum(fpath)
- return True if new_chksum == prev_chksum else False
+ new_chksum = calc_checksum(fpath)
+ return new_chksum == prev_chksum
def get_checkpoints(checkpoint_fp, fpaths=None):
@@ -358,7 +451,7 @@ def get_checkpoints(checkpoint_fp, fpaths=None):
----------
checkpoint_fp : str
- fpaths : [str, ...]
+ fpaths : [str, ...], optional
[, ...]
Returns
@@ -371,6 +464,7 @@ def get_checkpoints(checkpoint_fp, fpaths=None):
ValueError
When `checkpoint_fp` first being written, will not populate an empty checkpoints file.
Raises an error if the `fpaths` list is empty or None
+
"""
if not os.path.exists(checkpoint_fp):
logger.debug(f"{checkpoint_fp} not found... Writing")
@@ -381,10 +475,10 @@ def get_checkpoints(checkpoint_fp, fpaths=None):
outlines = ""
for fpath in fpaths:
try:
- checksum = get_checksum(fpath)
+ checksum = calc_checksum(fpath)
except FileNotFoundError as err:
checksum = ""
- outlines += f"{checksum}\t{fpath}\n"
+ outlines += checksum
with open(checkpoint_fp, "w") as fh:
fh.write(outlines)
logger.debug(f"Written: {checkpoint_fp}")
@@ -412,20 +506,19 @@ def update_checkpoints(checkpoint_fp, fpath):
-------
dict
{fp:checksum, ...}
+
"""
checkpoints = get_checkpoints(checkpoint_fp)
if valid_checkpoint(checkpoint_fp, fpath):
return checkpoints
- new_checksum = get_checksum(fpath)
+ new_checksum = calc_checksum(fpath)
checkpoints.update({fpath: new_checksum})
outlines = ""
for fp, chk in checkpoints.items():
outlines += f"{chk}\t{fp}\n"
with open(checkpoint_fp, "w") as fh:
fh.write(outlines)
- logger.debug(
- f"Updated checkpoints with {os.path.basename(fpath)} -> {new_checksum[:16]}"
- )
+ logger.debug(f"Checkpoints updated: {new_checksum[:16]} {os.path.basename(fpath)}")
return checkpoints
@@ -469,13 +562,9 @@ def wrapper(*args, **kwds):
if __name__ == "__main__":
- import argparse
-
- parser = argparse.ArgumentParser(
- description="file containing utilities functions for Autometa pipeline"
+ print(
+        "This file contains utilities for the Autometa pipeline and should not be run directly!"
)
- print("file containing utilities functions for Autometa pipeline")
- args = parser.parse_args()
import sys
- sys.exit(1)
+ sys.exit(0)
diff --git a/autometa/config/databases.py b/autometa/config/databases.py
index 19d65d805..2cd676baa 100644
--- a/autometa/config/databases.py
+++ b/autometa/config/databases.py
@@ -20,19 +20,21 @@
along with Autometa. If not, see .
COPYRIGHT
-Configuration handling for Autometa Databases.
+This file contains the Databases class responsible for configuration handling
+of Autometa Databases.
"""
import logging
import os
import requests
+import tempfile
import multiprocessing as mp
from configparser import ConfigParser
-from configparser import ExtendedInterpolation
from ftplib import FTP
+from glob import glob
from autometa.config import get_config
from autometa.config import DEFAULT_FPATH
@@ -40,6 +42,9 @@
from autometa.config import put_config
from autometa.config import AUTOMETA_DIR
from autometa.common.utilities import untar
+from autometa.common.utilities import calc_checksum
+from autometa.common.utilities import read_checksum
+from autometa.common.utilities import write_checksum
from autometa.common.external import diamond
from autometa.common.external import hmmer
@@ -50,10 +55,38 @@
class Databases:
- """docstring for Databases."""
+ """Database class containing methods to allow downloading/formatting/updating
+ Autometa database dependencies.
+
+ Parameters
+ ----------
+ config : config.ConfigParser
+ Config containing database dependency information.
+ (the default is DEFAULT_CONFIG).
+
+ dryrun : bool
+ Run through database checking without performing
+ downloads/formatting (the default is False).
+
+ nproc : int
+ Number of processors to use to perform database formatting.
+ (the default is mp.cpu_count()).
+
+ upgrade : bool
+ Overwrite existing databases with more up-to-date database
+ files. (the default is False).
+
+ Attributes
+ ----------
+    ncbi_dir : str
+        Path to the NCBI databases directory.
+    markers_dir : str
+        Path to the markers databases directory.
+    SECTIONS : dict
+        keys are `sections` respective to database config sections and
+        values are options within the `sections`.
+
+ """
SECTIONS = {
- "ncbi": ["nodes", "names", "merged", "accession2taxid", "nr",],
+ "ncbi": ["nodes", "names", "merged", "accession2taxid", "nr"],
"markers": [
"bacteria_single_copy",
"bacteria_single_copy_cutoffs",
@@ -62,31 +95,51 @@ class Databases:
],
}
- def __init__(self, config=DEFAULT_CONFIG, dryrun=False, nproc=mp.cpu_count()):
- if type(config) is not ConfigParser:
- raise TypeError(f"config is not ConfigParser : {type(config)}")
- if type(dryrun) is not bool:
- raise TypeError(f"dryrun must be True or False. type: {type(dryrun)}")
-
- self.config = config
- self.dryrun = dryrun
- self.nproc = nproc
- self.prepare_sections()
- self.ncbi_dir = self.config.get("databases", "ncbi")
- self.markers_dir = self.config.get("databases", "markers")
+ def __init__(
+ self, config=DEFAULT_CONFIG, dryrun=False, nproc=mp.cpu_count(), upgrade=False,
+ ):
+ """
- @property
- def satisfied(self):
- return self.get_missing(validate=True)
+    At instantiation of a Databases instance, if any of the respective
+ database directories do not exist, they will be created. This will be
+ reflected in the provided `config`.
- def prepare_sections(self):
- """Add database sections to 'databases' if missing.
+ Parameters
+ ----------
+ config : config.ConfigParser
+ Config containing database dependency information.
+ (the default is DEFAULT_CONFIG).
+ dryrun : bool
+ Run through database checking without performing downloads/
+ formatting (the default is False).
+
+ nproc : int
+ Number of processors to use to perform database formatting.
+ (the default is mp.cpu_count()).
+ upgrade : bool
+ Overwrite existing databases with more up-to-date database files.
+ (the default is False).
Returns
-------
- NoneType
+ databases.Databases: Databases object
+ instance of Databases class
"""
+ if not isinstance(config, ConfigParser):
+ raise TypeError(f"config is not ConfigParser : {type(config)}")
+ if not isinstance(dryrun, bool):
+ raise TypeError(f"dryrun must be boolean. type: {type(dryrun)}")
+
+ self.config = config
+ self.dryrun = dryrun
+ self.nproc = nproc
+ self.upgrade = upgrade
+ if self.config.get("common", "home_dir") == "None":
+            # necessary if the user is not running databases through the user
+            # endpoint, where :func:`~autometa.config.init_default` would've
+            # been called.
+ self.config.set("common", "home_dir", AUTOMETA_DIR)
if not self.config.has_section("databases"):
self.config.add_section("databases")
for section in Databases.SECTIONS:
@@ -94,9 +147,94 @@ def prepare_sections(self):
continue
outdir = DEFAULT_CONFIG.get("databases", section)
self.config.set("databases", section, outdir)
+ self.ncbi_dir = self.config.get("databases", "ncbi")
+ self.markers_dir = self.config.get("databases", "markers")
+ for outdir in {self.ncbi_dir, self.markers_dir}:
+ if not os.path.exists(outdir):
+ os.makedirs(outdir)
+
+ def satisfied(self, section=None, compare_checksums=False):
+ """Determines whether all database dependencies are satisfied.
+
+ Parameters
+ ----------
+        section : str, optional
+            Only check the provided `section`.
+            Choices include: 'ncbi' and 'markers' (the default is None,
+            which checks all sections).
+ compare_checksums : bool, optional
+ Also check if database information is up-to-date with current
+ hosted databases. (default is False).
+
+ Returns
+ -------
+ bool
+ True if all database dependencies are satisfied, otherwise False.
+
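+        Example
+        -------
+        >>> # Illustrative sketch (hypothetical return value):
+        >>> dbs = Databases(dryrun=True)
+        >>> dbs.satisfied(section="markers", compare_checksums=False)
+        True
+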
+ """
+ any_missing = self.get_missing(section=section)
+ if compare_checksums:
+ any_invalid = self.compare_checksums(section=section)
+ else:
+ any_invalid = {}
+ return not any_missing and not any_invalid
+
+ def get_remote_checksum(self, section, option):
+ """Get the checksum from provided `section` respective to `option` in
+ `self.config`.
+
+ Parameters
+ ----------
+        section : str
+            Config section the `option` belongs to; determines how the
+            remote checksum is retrieved.
+            Choices include: 'ncbi' (FTP) and 'markers' (HTTPS).
+ option : str
+ `option` in `checksums` section corresponding to the section
+ checksum file.
+
+ Returns
+ -------
+ str
+ checksum of remote md5 file. e.g. 'hash filename\n'
+
+ Raises
+ -------
+        ValueError
+            `section` is not 'ncbi' or 'markers'.
+        ConnectionError
+            Failed to connect to host for provided `option`.
+
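+        Example
+        -------
+        >>> # Illustrative only; output below is a hypothetical placeholder
+        >>> # in the documented 'hash filename' format.
+        >>> dbs = Databases(dryrun=True)
+        >>> dbs.get_remote_checksum("markers", "bacteria_single_copy")
+        'hash bacteria.single_copy.hmm'
+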
+ """
+ if section not in {"ncbi", "markers"}:
+ raise ValueError(
+ f'"section" must be "ncbi" or "markers". Provided: {section}'
+ )
+ if section == "ncbi":
+ host = self.config.get(section, "host")
+ ftp_fullpath = self.config.get("checksums", option)
+ chksum_fpath = ftp_fullpath.split(host)[-1]
+ with FTP(host) as ftp, tempfile.TemporaryFile() as fp:
+ ftp.login()
+ result = ftp.retrbinary(f"RETR {chksum_fpath}", fp.write)
+ if not result.startswith("226 Transfer complete"):
+ raise ConnectionError(f"{chksum_fpath} download failed")
+ ftp.quit()
+ fp.seek(0)
+ checksum = fp.read().decode()
+ elif section == "markers":
+ url = self.config.get("checksums", option)
+ with requests.Session() as session:
+ resp = session.get(url)
+ if not resp.ok:
+ raise ConnectionError(f"Failed to retrieve {url}")
+ checksum = resp.text
+ return checksum
def format_nr(self):
- """Format NCBI nr.gz database into diamond formatted database nr.dmnd.
+ """Construct a diamond formatted database (nr.dmnd) from `nr` option
+ in `ncbi` section in user config.
+
+ NOTE: The checksum 'nr.dmnd.md5' will only be generated if nr.dmnd
+ construction is successful. If the provided `nr` option in `ncbi` is
+ 'nr.gz' the database will be removed after successful database
+ formatting.
Returns
-------
@@ -105,23 +243,82 @@ def format_nr(self):
"""
db_infpath = self.config.get("ncbi", "nr")
+ db_infpath_md5 = f"{db_infpath}.md5"
db_outfpath = db_infpath.replace(".gz", ".dmnd")
- if not self.dryrun and not os.path.exists(db_outfpath):
- diamond.makedatabase(fasta=nr, database=db_infpath, nproc=self.nproc)
+
+ db_outfpath_exists = os.path.exists(db_outfpath)
+ if db_outfpath_exists:
+ db_outfpath_hash, __ = calc_checksum(db_outfpath).split()
+
+ remote_checksum_matches = False
+ current_nr_checksum_matches = False
+ # Check database and database checksum is up-to-date
+ if os.path.exists(db_infpath_md5) and db_outfpath_exists:
+ # Check if the current db md5 is up-to-date with the remote db md5
+ current_hash, __ = read_checksum(db_infpath_md5).split()
+ remote_hash, __ = self.get_remote_checksum("ncbi", "nr").split()
+ if remote_hash == current_hash:
+ remote_checksum_matches = True
+ # Check if the current db md5 matches the calc'd db checksum
+ if db_outfpath_hash == current_hash:
+ current_nr_checksum_matches = True
+
+ db_outfpath_md5 = f"{db_outfpath}.md5"
+ db_outfpath_md5_checksum_matches = False
+ if os.path.exists(db_outfpath_md5) and db_outfpath_exists:
+ db_outfpath_md5_hash, __ = read_checksum(db_outfpath_md5).split()
+ if db_outfpath_hash == db_outfpath_md5_hash:
+ db_outfpath_md5_checksum_matches = True
+
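+        # The checks below are evaluated in order; the first match means the
+        # formatted database is already current, so the config is updated and
+        # formatting is skipped. A mismatch against the remote nr.gz.md5 only
+        # triggers re-formatting when self.upgrade is True.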
+ checksum_checks = ["nr.dmnd.md5", "nr.gz.md5", "remote nr.gz.md5"]
+ checksum_matches = [
+ db_outfpath_md5_checksum_matches,
+ current_nr_checksum_matches,
+ remote_checksum_matches,
+ ]
+ for checksum_match, checksum_check in zip(checksum_matches, checksum_checks):
+            # If this checksum matches, the formatted db is already current.
+ if checksum_match:
+ logger.debug(f"{checksum_check} checksum matches, skipping...")
+ self.config.set("ncbi", "nr", db_outfpath)
+ logger.debug(f"set ncbi nr: {db_outfpath}")
+ return
+ # Only update out-of-date db files if user wants to update via self.upgrade
+ if not self.upgrade and checksum_check == "remote nr.gz.md5":
+ return
+
+ diamond.makedatabase(fasta=db_infpath, database=db_outfpath, nproc=self.nproc)
+ # Write checksum for nr.dmnd
+ write_checksum(db_outfpath, db_outfpath_md5)
+
+ if os.path.basename(db_infpath) == "nr.gz":
+ # nr.gz will be removed after successful nr.dmnd construction
+ os.remove(db_infpath)
+
self.config.set("ncbi", "nr", db_outfpath)
logger.debug(f"set ncbi nr: {db_outfpath}")
def extract_taxdump(self):
- """Extract autometa required files from ncbi taxdump directory.
+ """Extract autometa required files from ncbi taxdump.tar.gz archive
+ into ncbi databases directory and update user config with extracted
+ paths.
- Extracts nodes.dmp, names.dmp and merged.dmp from taxdump.tar.gz
+ This only extracts nodes.dmp, names.dmp and merged.dmp from
+ taxdump.tar.gz if the files do not already exist. If `upgrade`
+ was originally supplied as `True` to the Databases instance, then the
+ previous files will be replaced by the new taxdump files.
+
+ After successful extraction of the files, a checksum will be written
+ of the archive for future checking.
Returns
-------
NoneType
+ Will update `self.config` section `ncbi` with options 'nodes',
+            'names', and 'merged'.
"""
- taxdump = self.config.get("ncbi", "taxdump")
+ taxdump_fpath = self.config.get("ncbi", "taxdump")
taxdump_files = [
("nodes", "nodes.dmp"),
("names", "names.dmp"),
@@ -129,67 +326,121 @@ def extract_taxdump(self):
]
for option, fname in taxdump_files:
outfpath = os.path.join(self.ncbi_dir, fname)
- if not self.dryrun and not os.path.exists(outfpath):
- outfpath = untar(taxdump, self.ncbi_dir, fname)
+ if self.dryrun:
+ logger.debug(f"UPDATE (ncbi,{option}): {outfpath}")
+ self.config.set("ncbi", option, outfpath)
+ continue
+ # Only update the taxdump files if the user says to do an update.
+ if self.upgrade and os.path.exists(outfpath):
+ os.remove(outfpath)
+            # Extract the file only if it does not already exist.
+ if not os.path.exists(outfpath):
+ outfpath = untar(taxdump_fpath, self.ncbi_dir, fname)
+ write_checksum(outfpath, f"{outfpath}.md5")
+
logger.debug(f"UPDATE (ncbi,{option}): {outfpath}")
self.config.set("ncbi", option, outfpath)
- def update_ncbi(self, options):
- """Update NCBI database files (taxdump.tar.gz and nr.gz).
+ def download_ncbi_files(self, options):
+ """Download NCBI database files.
Parameters
----------
- options : set
- Set of options to update
+ options : iterable
+ iterable containing options in 'ncbi' section to download.
Returns
-------
NoneType
- Description of returned object.
+ Will update provided `options` in `self.config`.
Raises
-------
+ subprocess.CalledProcessError
+ NCBI file download with rsync failed.
ConnectionError
- NCBI file download failed.
+ NCBI file checksums do not match after file transfer.
"""
- # Download required NCBI database files
- if not os.path.exists(self.ncbi_dir):
- os.makedirs(self.ncbi_dir)
- host = DEFAULT_CONFIG.get("ncbi", "host")
+ host = self.config.get("ncbi", "host")
for option in options:
- ftp_fullpath = DEFAULT_CONFIG.get("database_urls", option)
- ftp_fpath = ftp_fullpath.split(host)[-1]
- if self.config.has_option("ncbi", option):
+ ftp_fullpath = self.config.get("database_urls", option)
+
+ if (
+ self.config.has_option("ncbi", option)
+                and self.config.get("ncbi", option) != "None"
+ ):
outfpath = self.config.get("ncbi", option)
else:
- outfname = os.path.basename(ftp_fpath)
+ outfname = os.path.basename(ftp_fullpath)
outfpath = os.path.join(self.ncbi_dir, outfname)
- if not self.dryrun and not os.path.exists(outfpath):
- with FTP(host) as ftp, open(outfpath, "wb") as fp:
- ftp.login()
- logger.debug(f"starting {option} download")
- result = ftp.retrbinary(f"RETR {ftp_fpath}", fp.write)
- if not result.startswith("226 Transfer complete"):
- raise ConnectionError(f"{option} download failed")
- ftp.quit()
+
logger.debug(f"UPDATE: (ncbi,{option}): {outfpath}")
self.config.set("ncbi", option, outfpath)
- # Extract/format respective NCBI files
- self.extract_taxdump()
- self.format_nr()
- def update_markers(self, options):
- """Update single-copy markers hmms and cutoffs.
+ if self.dryrun:
+ return
+
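+            # NCBI serves the same paths over rsync; swap only the URL scheme
+            # so the host and remote path are left intact.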
+            rsync_fpath = ftp_fullpath.replace("ftp://", "rsync://", 1)
+ cmd = ["rsync", "--quiet", "--archive", rsync_fpath, outfpath]
+ logger.debug(f"starting {option} download")
+ subprocess.run(
+ cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
+ )
+ checksum_outfpath = f"{outfpath}.md5"
+ write_checksum(outfpath, checksum_outfpath)
+ current_checksum = read_checksum(checksum_outfpath)
+ current_hash, __ = current_checksum.split()
+ remote_checksum = self.get_remote_checksum("ncbi", option)
+ remote_hash, __ = remote_checksum.split()
+            if current_hash != remote_hash:
+ raise ConnectionError(f"{option} download failed")
+ if "taxdump" in options:
+ self.extract_taxdump()
+ if "nr" in options:
+ self.format_nr()
+
+ def press_hmms(self):
+ """hmmpress markers hmm database files.
+
+ Returns
+ -------
+ NoneType
+
+ """
+ hmm_search_str = os.path.join(self.markers_dir, "*.h3?")
+ # First search for pressed hmms to remove from list to hmmpress
+ pressed_hmms = {
+ os.path.realpath(os.path.splitext(fp)[0])
+ for fp in glob(hmm_search_str)
+ if not fp.endswith(".md5")
+ }
+ # Now retrieve all hmms in markers directory
+ hmms = (
+ os.path.join(self.markers_dir, fn)
+ for fn in os.listdir(self.markers_dir)
+ if fn.endswith(".hmm")
+ )
+ # Filter by hmms not already pressed
+ hmms = (fpath for fpath in hmms if fpath not in pressed_hmms)
+ # Press hmms and write checksums of their indices
+ for hmm_fp in hmms:
+ hmmer.hmmpress(hmm_fp)
+ for index_fp in glob(f"{hmm_fp}.h3?"):
+ write_checksum(index_fp, f"{index_fp}.md5")
+
+ def download_markers(self, options):
+ """Download markers database files and amend user config to reflect this.
Parameters
----------
- options : set
- Description of parameter `options`.
+ options : iterable
+ iterable containing options in 'markers' section to download.
Returns
-------
NoneType
+ Will update provided `options` in `self.config`.
Raises
-------
@@ -197,105 +448,239 @@ def update_markers(self, options):
marker file download failed.
"""
- if not os.path.exists(self.markers_dir):
- os.makedirs(self.markers_dir)
for option in options:
- url = DEFAULT_CONFIG.get("database_urls", option)
+ # First retrieve the markers file url from `option` in `markers`
+ url = self.config.get("database_urls", option)
if self.config.has_option("markers", option):
outfpath = self.config.get("markers", option)
else:
outfname = os.path.basename(url)
outfpath = os.path.join(self.markers_dir, outfname)
+
if self.dryrun:
logger.debug(f"UPDATE: (markers,{option}): {outfpath}")
self.config.set("markers", option, outfpath)
continue
- with requests.Session() as session:
+
+ # Retrieve markers file and write contents to `outfpath`
+ with requests.Session() as session, open(outfpath, "w") as fh:
resp = session.get(url)
- if not resp.ok:
- raise ConnectionError(f"Failed to retrieve {url}")
- with open(outfpath, "w") as outfh:
- outfh.write(resp.text)
+ if not resp.ok:
+ raise ConnectionError(f"Failed to retrieve {url}")
+ fh.write(resp.text)
self.config.set("markers", option, outfpath)
- if outfpath.endswith(".hmm"):
- hmmer.hmmpress(outfpath)
+ checksum_outfpath = f"{outfpath}.md5"
+ write_checksum(outfpath, checksum_outfpath)
+ current_checksum = read_checksum(checksum_outfpath)
+ current_hash, __ = current_checksum.split()
+ remote_checksum = self.get_remote_checksum("markers", option)
+ remote_hash, __ = remote_checksum.split()
+            if current_hash != remote_hash:
+ raise ConnectionError(f"{option} download failed")
+ self.press_hmms()
+
+ def get_missing(self, section=None):
+ """Get all missing database files in `options` from `sections`
+ in config.
+
+ Parameters
+ ----------
+ section : str, optional
+            Only check the provided `section`. Choices include 'markers' and
+            'ncbi'. (default will check all database sections)
+
+ Returns
+ -------
+ dict
+ {section:{option, option,...}, section:{...}, ...}
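+
+        Example
+        -------
+        >>> # Illustrative output shape (hypothetical values):
+        >>> dbs = Databases(dryrun=True)
+        >>> dbs.get_missing(section="ncbi")
+        {'ncbi': {'nr', 'taxdump'}}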
- def get_missing(self, validate=False):
- """Retrieve all database files from all database sections that are not
- available.
+ """
+ sections = [section] if section else Databases.SECTIONS.keys()
+ missing = {}
+ for section in sections:
+ for option in self.config.options(section):
+ # Skip user added options not required by Autometa
+ if option not in Databases.SECTIONS.get(section):
+ continue
+ fpath = self.config.get(section, option)
+ if os.path.exists(fpath):
+ continue
+ if section in missing:
+ missing[section].add(option)
+ else:
+ missing.update({section: set([option])})
+ # Log missing options
+ for section, options in missing.items():
+ for option in options:
+ logger.debug(f"MISSING: ({section},{option})")
+ return missing
+
+ def download_missing(self, section=None):
+ """Download missing Autometa database dependencies from provided `section`.
+ If no `section` is provided will check all sections.
+
+ Parameters
+ ----------
+ section : str, optional
+ Section to check for missing database files (the default is None).
+ Choices include 'ncbi' and 'markers'.
Returns
-------
- bool or dict
+ NoneType
+ Will update provided `section` in `self.config`.
- - if `validate` is True : bool
+ Raises
+ -------
+ ValueError
+            Provided `section` does not match 'ncbi' or 'markers'.
- all available evaluates to True, otherwise False
+ """
+ dispatcher = {
+ "ncbi": self.download_ncbi_files,
+ "markers": self.download_markers,
+ }
+ if section and section not in dispatcher:
+ raise ValueError(f'{section} does not match "ncbi" or "markers"')
+ if section:
+ missing = self.get_missing(section=section)
+ options = missing.get(section, [])
+ dispatcher[section](options)
+ else:
+ missing = self.get_missing()
+ for section, options in missing.items():
+ dispatcher[section](options)
- - if `validate` is False : dict
+ def compare_checksums(self, section=None):
+ """Get all invalid database files in `options` from `section`
+        in config. An md5 checksum comparison is performed between the local
+        file's md5 and its remote md5 to verify file integrity before marking
+        the respective file as valid.
- {section:{option, option,...}, section:{...}, ...}
+ Parameters
+ ----------
+        section : str, optional
+            Only check the provided `section`. Choices include 'markers'
+            and 'ncbi'. (default will check all database sections)
+
+        Returns
+        -------
+        dict
+            {section:{option, option,...}, section:{...}, ...}
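+
+        Example
+        -------
+        >>> # Illustrative (hypothetical values): an out-of-date taxdump
+        >>> dbs = Databases(dryrun=True)
+        >>> dbs.compare_checksums(section="ncbi")
+        {'ncbi': {'taxdump'}}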
"""
- missing = {}
- for section in Databases.SECTIONS:
+ sections = [section] if section else Databases.SECTIONS.keys()
+ invalid = {}
+ taxdump_checked = False
+ for section in sections:
for option in self.config.options(section):
if option not in Databases.SECTIONS.get(section):
# Skip user added options not required by Autometa
continue
+ # nodes.dmp, names.dmp and merged.dmp are all in taxdump.tar.gz
+ option = "taxdump" if option in {"nodes", "names", "merged"} else option
fpath = self.config.get(section, option)
- if os.path.exists(fpath) and os.stat(fpath).st_size >= 0:
- # TODO: [Checkpoint validation]
- logger.debug(f"({section},{option}): {fpath}")
+ fpath_md5 = f"{fpath}.md5"
+                # We cannot checksum a file that does not exist.
+ if not os.path.exists(fpath) and not os.path.exists(fpath_md5):
continue
- if validate:
- return False
- if section in missing:
- missing[section].add(option)
+                # Avoid checking the taxdump files three times (nodes, names
+                # and merged all map to the same archive).
+ if option == "taxdump" and taxdump_checked:
+ continue
+ if os.path.exists(fpath_md5):
+ current_checksum = read_checksum(fpath_md5)
else:
- missing.update({section: set([option])})
- for section, opts in missing.items():
- for opt in opts:
- logger.debug(f"MISSING: ({section},{opt})")
- return True if validate else missing
+ current_checksum = calc_checksum(fpath)
+ current_hash, __ = current_checksum.split()
+ try:
+ remote_checksum = self.get_remote_checksum(section, option)
+ remote_hash, __ = remote_checksum.split()
+ except ConnectionError as err:
+ # Do not mark file as invalid if a connection error occurs.
+ logger.warning(err)
+ continue
+ if option == "taxdump":
+ taxdump_checked = True
+ if remote_hash == current_hash:
+ logger.debug(f"{option} checksums match, skipping...")
+ continue
+ if section in invalid:
+ invalid[section].add(option)
+ else:
+ invalid.update({section: set([option])})
+ # Log invalid options
+ for section, options in invalid.items():
+ for option in options:
+ logger.debug(f"INVALID: ({section},{option})")
+ return invalid
- def update_missing(self):
- """Download and format databases for all options in each section.
+ def fix_invalid_checksums(self, section=None):
+ """Download/Update/Format databases where checksums are out-of-date.
- NOTE: This will only perform the download and formatting if self.dryrun is False
+ Parameters
+ ----------
+ section : str, optional
+ Configure provided `section`. Choices include 'markers' and 'ncbi'.
+ (default will download/format all database directories)
Returns
-------
NoneType
- config updated with required missing sections.
+ Will update provided `options` in `self.config`.
+
+ Raises
+ -------
+        ValueError
+            Provided `section` does not match 'ncbi' or 'markers'.
+        ConnectionError
+            Failed to connect to `section` host site.
"""
- dispatcher = {"ncbi": self.update_ncbi, "markers": self.update_markers}
- missing = self.get_missing()
- for section, options in missing.items():
- if section == "ncbi":
- if "nodes" in options or "names" in options or "merged" in options:
- options.discard("nodes")
- options.discard("names")
- options.discard("merged")
- options.add("taxdump")
+ dispatcher = {
+ "ncbi": self.download_ncbi_files,
+ "markers": self.download_markers,
+ }
+ if section and section not in dispatcher:
+ raise ValueError(f'{section} does not match "ncbi" or "markers"')
+ if section:
+ invalid = self.compare_checksums(section=section)
+ options = invalid.get(section, set())
dispatcher[section](options)
+ else:
+ invalid = self.compare_checksums()
+ for section, options in invalid.items():
+ dispatcher[section](options)
- def configure(self):
- """Checks database files
+ def configure(self, section=None, no_checksum=False):
+ """Configures Autometa's database dependencies by first checking missing
+ dependencies then comparing checksums to ensure integrity of files.
+
+ Download and format databases for all options in each section.
+
+ This will only perform the download and formatting if `self.dryrun` is
+ False. This will update out-of-date databases if `self.upgrade` is
+ True.
+
+ Parameters
+ ----------
+        section : str, optional
+            Configure only the provided `section`. Choices include 'markers'
+            and 'ncbi'. (default will download/format all database directories)
+        no_checksum : bool, optional
+            Do not perform checksum comparisons (the default is False).
Returns
-------
- configparser.ConfigParser
- config with updated options in respective databases sections.
+        configparser.ConfigParser
+            Config with updated options in respective databases sections.
Raises
-------
- ExceptionName
- Why the exception is raised.
+        ValueError
+            Provided `section` does not match 'ncbi' or 'markers'.
+        ConnectionError
+            A connection issue occurred when connecting to NCBI or GitHub.
"""
- self.update_missing()
+ self.download_missing(section=section)
+ if no_checksum:
+ return self.config
+ self.fix_invalid_checksums(section=section)
return self.config
@@ -303,7 +688,6 @@ def main():
import argparse
import logging as logger
- cpus = mp.cpu_count()
logger.basicConfig(
format="[%(asctime)s %(levelname)s] %(name)s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
@@ -311,11 +695,13 @@ def main():
)
parser = argparse.ArgumentParser(
- description="databases config",
- epilog="By default, with no arguments, will download/format databases into default databases directory.",
+ description="Main script to configure Autometa database dependencies.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ epilog="By default, with no arguments, will download/format databases "
+ "into default databases directory.",
)
parser.add_argument(
- "--config", help="", default=DEFAULT_FPATH
+ "--config", help="", default=DEFAULT_FPATH,
)
parser.add_argument(
"--dryrun",
@@ -323,27 +709,81 @@ def main():
action="store_true",
default=False,
)
+ parser.add_argument(
+ "--update",
+ help="Update all out-of-date databases.",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--update-markers",
+ help="Update out-of-date markers databases.",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--update-ncbi",
+ help="Update out-of-date ncbi databases.",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--check-dependencies",
+ help="Check database dependencies are satisfied.",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--no-checksum",
+ help="Do not perform remote checksum comparisons to validate databases"
+ "are up-to-date.",
+ action="store_true",
+ default=False,
+ )
parser.add_argument(
"--nproc",
- help=f"num. cpus to use for DB formatting. (default {cpus})",
+ help="num. cpus to use for DB formatting.",
type=int,
- default=cpus,
+ default=mp.cpu_count(),
)
parser.add_argument("--out", help="")
args = parser.parse_args()
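+    # Example invocations (illustrative):
+    #   python databases.py --update-markers
+    #   python databases.py --check-dependencies --no-checksum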
config = get_config(args.config)
- dbs = Databases(config=config, dryrun=args.dryrun, nproc=args.nproc)
- logger.debug(f"Configuring databases")
- config = dbs.configure()
- dbs = Databases(config=config, dryrun=args.dryrun, nproc=args.nproc)
- logger.info(f"Database dependencies satisfied: {dbs.satisfied}")
+ dbs = Databases(
+ config=config, dryrun=args.dryrun, nproc=args.nproc, upgrade=args.update,
+ )
+
+    # Compare checksums whenever any update flag is set (unless disabled).
+    any_update = args.update or args.update_markers or args.update_ncbi
+    compare_checksums = any_update and not args.no_checksum
+
+ if args.update_markers:
+ section = "markers"
+ elif args.update_ncbi:
+ section = "ncbi"
+ else:
+ section = None
+
+ if args.check_dependencies:
+ dbs_satisfied = dbs.satisfied(
+ section=section, compare_checksums=compare_checksums
+ )
+ logger.info(f"Database dependencies satisfied: {dbs_satisfied}")
+ import sys
+
+ sys.exit(0)
+
+    logger.debug("Configuring databases")
+ config = dbs.configure(section=section, no_checksum=args.no_checksum)
+
if not args.out:
import sys
sys.exit(0)
put_config(config, args.out)
- logger.debug(f"{args.out} written.")
+ logger.info(f"{args.out} written.")
if __name__ == "__main__":
diff --git a/autometa/config/default.config b/autometa/config/default.config
index 0b4a1ac38..7475f595c 100644
--- a/autometa/config/default.config
+++ b/autometa/config/default.config
@@ -56,15 +56,19 @@ markers = ${databases:base}/markers
taxdump = ftp://${ncbi:host}/pub/taxonomy/taxdump.tar.gz
accession2taxid = ftp://${ncbi:host}/pub/taxonomy/accession2taxid/prot.accession2taxid.gz
nr = ftp://${ncbi:host}/blast/db/FASTA/nr.gz
-bacteria_single_copy = https://github.com/WiscEvan/Autometa/raw/dev/databases/markers/bacteria.single_copy.hmm
-bacteria_single_copy_cutoffs = https://${markers:host}/WiscEvan/Autometa/dev/databases/markers/bacteria.single_copy.cutoffs?token=AGF3KQVL3J4STDT4TJQVDBS6GG5FE
-archaea_single_copy = https://github.com/WiscEvan/Autometa/raw/dev/databases/markers/archaea.single_copy.hmm
-archaea_single_copy_cutoffs = https://${markers:host}/WiscEvan/Autometa/dev/databases/markers/archaea.single_copy.cutoffs?token=AGF3KQXVUDFIH6ECVTYMZQS6GG5KO
+bacteria_single_copy = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/bacteria.single_copy.hmm
+bacteria_single_copy_cutoffs = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/bacteria.single_copy.cutoffs
+archaea_single_copy = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/archaea.single_copy.hmm
+archaea_single_copy_cutoffs = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/archaea.single_copy.cutoffs
[checksums]
taxdump = ftp://${ncbi:host}/pub/taxonomy/taxdump.tar.gz.md5
accession2taxid = ftp://${ncbi:host}/pub/taxonomy/accession2taxid/prot.accession2taxid.gz.md5
nr = ftp://${ncbi:host}/blast/db/FASTA/nr.gz.md5
+bacteria_single_copy = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/bacteria.single_copy.hmm.md5
+bacteria_single_copy_cutoffs = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/bacteria.single_copy.cutoffs.md5
+archaea_single_copy = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/archaea.single_copy.hmm.md5
+archaea_single_copy_cutoffs = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/archaea.single_copy.cutoffs.md5
[ncbi]
host = ftp.ncbi.nlm.nih.gov
@@ -98,8 +102,8 @@ bed = alignments.bed
length_filtered = metagenome.filtered.fna
coverages = coverages.tsv
kmer_counts = kmers.tsv
-kmer_normalized = kmers.normalized.tsv
-kmer_embedded = kmers.embedded.tsv
+kmer_normalized = kmers.normalized.tsv
+kmer_embedded = kmers.embedded.tsv
nucleotide_orfs = metagenome.filtered.orfs.fna
amino_acid_orfs = metagenome.filtered.orfs.faa
blastp = blastp.tsv
diff --git a/autometa/config/user.py b/autometa/config/user.py
index 77483f72d..caa599333 100644
--- a/autometa/config/user.py
+++ b/autometa/config/user.py
@@ -34,7 +34,7 @@
from autometa.common import utilities
from autometa.common.metagenome import Metagenome
-from autometa.common.mag import MAG
+from autometa.common.metabin import MetaBin
from autometa.config.databases import Databases
from autometa.config.project import Project
from autometa.common.utilities import timeit
diff --git a/autometa/databases/markers/archaea.single_copy.cutoffs.md5 b/autometa/databases/markers/archaea.single_copy.cutoffs.md5
new file mode 100644
index 000000000..0be5450fe
--- /dev/null
+++ b/autometa/databases/markers/archaea.single_copy.cutoffs.md5
@@ -0,0 +1 @@
+d0121bd751e0d454a81c54a0655ce745 archaea.single_copy.cutoffs
diff --git a/autometa/databases/markers/archaea.single_copy.hmm.md5 b/autometa/databases/markers/archaea.single_copy.hmm.md5
new file mode 100644
index 000000000..2a966fe72
--- /dev/null
+++ b/autometa/databases/markers/archaea.single_copy.hmm.md5
@@ -0,0 +1 @@
+fb0f644a2df855fa7145436564091e07 archaea.single_copy.hmm
diff --git a/autometa/databases/markers/bacteria.single_copy.cutoffs.md5 b/autometa/databases/markers/bacteria.single_copy.cutoffs.md5
new file mode 100644
index 000000000..fe76ec24c
--- /dev/null
+++ b/autometa/databases/markers/bacteria.single_copy.cutoffs.md5
@@ -0,0 +1 @@
+936e6cb46002caa08b3690b8446086fb bacteria.single_copy.cutoffs
diff --git a/autometa/databases/markers/bacteria.single_copy.hmm.md5 b/autometa/databases/markers/bacteria.single_copy.hmm.md5
new file mode 100644
index 000000000..44b324798
--- /dev/null
+++ b/autometa/databases/markers/bacteria.single_copy.hmm.md5
@@ -0,0 +1 @@
+eafdb5bd1447e814a4ee47ee11434ffe bacteria.single_copy.hmm
diff --git a/docs/parse_argparse.py b/docs/parse_argparse.py
index 08c5f1124..ab47b8974 100644
--- a/docs/parse_argparse.py
+++ b/docs/parse_argparse.py
@@ -103,6 +103,11 @@ def get_usage(argparse_lines):
-------
wrapped_lines : string
         indented argparse output after running the `--help` command
+
+ Raises
+ ------
+ subprocess.CalledProcessError
+ Error while running --help on these argparse lines
"""
__, tmp_fpath = tempfile.mkstemp()
with open(tmp_fpath, "w") as outfh:
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 8d3f5c97f..bd9d48d07 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -19,7 +19,8 @@
autodoc_mock_imports = ["Bio", "hdbscan", "tsne", "sklearn", "umap", "tqdm"]
-import parse_argparse # nopep8
+# fmt: off
+import parse_argparse
# -- Project information -----------------------------------------------------
@@ -48,7 +49,8 @@
]
todo_include_todos = True
-
+# Includes docstrings of functions that begin with double underscore
+napoleon_include_special_with_doc = True
autosummary_generate = True
# Add any paths that contain templates here, relative to this directory.