diff --git a/.condarc.yaml b/.condarc.yaml deleted file mode 100644 index 90c78f428..000000000 --- a/.condarc.yaml +++ /dev/null @@ -1,6 +0,0 @@ -channels: - - defaults - - bioconda - - conda-forge -show_channels_urls: True -default_threads: 6 diff --git a/.github/ISSUE_TEMPLATE/support_request.md b/.github/ISSUE_TEMPLATE/support_request.md new file mode 100644 index 000000000..c36491bec --- /dev/null +++ b/.github/ISSUE_TEMPLATE/support_request.md @@ -0,0 +1,52 @@ +--- +name: Support request +about: Need clarification on a few scripts or the pipeline in general? Ask us! +--- + + + +### User checklist + +- [ ] Are you using the latest release? +- [ ] Are you using python 3? +- [ ] Did you check previous issues to see if this has already been mentioned? +- [ ] Are you using a Mac or Linux machine? + +#### Description + + + +#### Expected Behavior + + + +#### System Environment + + + +- Operating System: +- RAM: +- Disk: + +#### Tasks/Command(s) + + + +- [ ] Task 1 +- [ ] Task 2 +- [ ] Task 3 +- [ ] etc. + +
+Log/Error information generated by Autometa.
+
+
+```
+
+```
+
+

diff --git a/.gitignore b/.gitignore index de5adf94f..dca11f12a 100644 --- a/.gitignore +++ b/.gitignore @@ -138,13 +138,14 @@ dmypy.json autometa/*.pyc autometa/taxonomy/*.pyc autometa/databases/markers/*.h3* +autometa/databases/ncbi/* # databases / testing tests/data/* !tests/data/metagenome.fna # visualStudioCode -.vscode/* +.vscode/ !.vscode/settings.json !.vscode/tasks.json !.vscode/launch.json diff --git a/autometa/binning/recursive_dbscan.py b/autometa/binning/recursive_dbscan.py index 8cccd3fdf..f85825a0f 100644 --- a/autometa/binning/recursive_dbscan.py +++ b/autometa/binning/recursive_dbscan.py @@ -26,8 +26,11 @@ import logging import os +import shutil +import tempfile import pandas as pd +import numpy as np from Bio import SeqIO from sklearn.cluster import DBSCAN @@ -38,7 +41,7 @@ # TODO: This should be from autometa.common.kmers import Kmers # So later we can simply/and more clearly do Kmers.load(kmers_fpath).embed(method) -from autometa.common.exceptions import RecursiveDBSCANError +from autometa.common.exceptions import BinningError from autometa.taxonomy.ncbi import NCBI pd.set_option("mode.chained_assignment", None) @@ -47,14 +50,114 @@ logger = logging.getLogger(__name__) +def add_metrics(df, markers_df, domain="bacteria"): + """Adds the completeness and purity metrics to each respective contig in df. + + Parameters + ---------- + df : pd.DataFrame + index='contig' cols=['x','y','coverage','cluster'] + markers_df : pd.DataFrame + wide format, i.e. index=contig cols=[marker,marker,...] + domain : str, optional + Kingdom to determine metrics (the default is 'bacteria'). + choices=['bacteria','archaea'] + + Returns + ------- + 2-tuple + `df` with added cols=['completeness', 'purity'] + pd.DataFrame(index=clusters, cols=['completeness', 'purity']) + + Raises + ------- + KeyError + `domain` is not "bacteria" or "archaea" + + """ + domain = domain.lower() + marker_sets = {"bacteria": 139.0, "archaea": 162.0} + if domain not in marker_sets: + raise KeyError(f"{domain} is not bacteria or archaea!") + expected_number = marker_sets[domain] + clusters = dict(list(df.groupby("cluster"))) + metrics = {"purity": {}, "completeness": {}} + for cluster, dff in clusters.items(): + contigs = dff.index.tolist() + summed_markers = markers_df[markers_df.index.isin(contigs)].sum() + is_present = summed_markers >= 1 + is_single_copy = summed_markers == 1 + nunique_markers = summed_markers[is_present].index.nunique() + num_single_copy_markers = summed_markers[is_single_copy].index.nunique() + completeness = nunique_markers / expected_number * 100 + # Protect from divide by zero + if nunique_markers == 0: + purity = pd.NA + else: + purity = num_single_copy_markers / nunique_markers * 100 + metrics["completeness"].update({cluster: completeness}) + metrics["purity"].update({cluster: purity}) + metrics_df = pd.DataFrame(metrics, index=clusters.keys()) + merged_df = pd.merge(df, metrics_df, left_on="cluster", right_index=True) + return merged_df, metrics_df + + +def run_dbscan(df, eps, dropcols=["cluster", "purity", "completeness"]): + """Run clustering on `df` at provided `eps`. + + Notes + ----- + + * documentation for sklearn `DBSCAN `_ + * documentation for `HDBSCAN `_ + + Parameters + ---------- + df : pd.DataFrame + Contigs with embedded k-mer frequencies as ['x','y'] columns and optionally 'coverage' column + eps : float + The maximum distance between two samples for one to be considered + as in the neighborhood of the other. 
This is not a maximum bound + on the distances of points within a cluster. This is the most + important DBSCAN parameter to choose appropriately for your data set + and distance function. See `DBSCAN docs `_ for more details. + dropcols : list, optional + Drop columns in list from `df` + (the default is ['cluster','purity','completeness']). + + Returns + ------- + pd.DataFrame + `df` with 'cluster' column added + + Raises + ------- + ValueError + sets `usecols` and `dropcols` may not share elements + + """ + for col in dropcols: + if col in df.columns: + df.drop(columns=col, inplace=True) + n_samples = df.shape[0] + if n_samples == 1: + clusters = pd.Series([pd.NA], index=df.index, name="cluster") + return pd.merge(df, clusters, how="left", left_index=True, right_index=True) + cols = ["x", "y"] + if "coverage" in df.columns: + cols.append("coverage") + if np.any(df.isnull()): + raise BinningError( + f"df is missing {df.isnull().sum().sum()} kmer/coverage annotations" + ) + X = df.loc[:, cols].to_numpy() + clusterer = DBSCAN(eps=eps, min_samples=1, n_jobs=-1).fit(X) + clusters = pd.Series(clusterer.labels_, index=df.index, name="cluster") + return pd.merge(df, clusters, how="left", left_index=True, right_index=True) + + def recursive_dbscan( - table, - markers_df, - domain, - completeness_cutoff, - purity_cutoff, - verbose=False, - method="DBSCAN", + table, markers_df, domain, completeness_cutoff, purity_cutoff, verbose=False, ): """Carry out DBSCAN, starting at eps=0.3 and continuing until there is just one group. @@ -83,9 +186,6 @@ def recursive_dbscan( `purity_cutoff` threshold to retain cluster (the default is 90.0). verbose : bool log stats for each recursive_dbscan clustering iteration. - method : str - clustering `method` to perform. (the default is 'DBSCAN'). - Choices=['DBSCAN','HDBSCAN'] NOTE: 'HDBSCAN' is still under development. 
Returns ------- @@ -105,11 +205,11 @@ def recursive_dbscan( best_median = float("-inf") best_df = pd.DataFrame() while n_clusters > 1: - binned_df = run_dbscan(table, eps, method=method) - df = add_metrics(df=binned_df, markers_df=markers_df, domain=domain) - completess_filter = df["completeness"] >= completeness_cutoff - purity_filter = df["purity"] >= purity_cutoff - median_completeness = df[completess_filter & purity_filter][ + binned_df = run_dbscan(table, eps) + df, metrics_df = add_metrics(df=binned_df, markers_df=markers_df, domain=domain) + completeness_filter = metrics_df["completeness"] >= completeness_cutoff + purity_filter = metrics_df["purity"] >= purity_cutoff + median_completeness = metrics_df[completeness_filter & purity_filter][ "completeness" ].median() if median_completeness >= best_median: @@ -139,9 +239,9 @@ def recursive_dbscan( logger.debug("No complete or pure clusters found") return pd.DataFrame(), table - completess_filter = best_df["completeness"] >= completeness_cutoff + completeness_filter = best_df["completeness"] >= completeness_cutoff purity_filter = best_df["purity"] >= purity_cutoff - complete_and_pure_df = best_df.loc[completess_filter & purity_filter] + complete_and_pure_df = best_df.loc[completeness_filter & purity_filter] unclustered_df = best_df.loc[~best_df.index.isin(complete_and_pure_df.index)] if verbose: logger.debug(f"Best completeness median: {best_median:4.2f}") @@ -151,30 +251,36 @@ def recursive_dbscan( return complete_and_pure_df, unclustered_df -def run_dbscan( +def run_hdbscan( df, - eps, + min_cluster_size, + min_samples, + cache_dir=None, dropcols=["cluster", "purity", "completeness"], - usecols=["x", "y"], - method="DBSCAN", ): - """Run clustering via `method`. + """Run clustering on `df` at provided `min_cluster_size`. + + Notes + ----- + + * reasoning for parameter: `cluster_selection_method `_ + * reasoning for parameters: `min_cluster_size and min_samples `_ + * documentation for `HDBSCAN `_ Parameters ---------- df : pd.DataFrame - Description of parameter `df`. - eps : float - Description of parameter `eps`. - dropcols : list - Drop columns in list from `df` (the default is ['cluster','purity','completeness']). - usecols : list - Use columns in list for `df`. - The default is ['x','y','coverage'] if 'coverage' exists in df.columns. - else ['x','y','z']. - method : str - clustering `method` to perform. (the default is 'DBSCAN'). - Choices=['DBSCAN','HDBSCAN'] NOTE: 'HDBSCAN' is still under development + Contigs with embedded k-mer frequencies as ['x','y'] columns and optionally 'coverage' column + min_cluster_size : int + The minimum size of clusters; single linkage splits that contain + fewer points than this will be considered points "falling out" of a + cluster rather than a cluster splitting into two new clusters. + min_samples : int + The number of samples in a neighborhood for a point to be + considered a core point. + dropcols : list, optional + Drop columns in list from `df` + (the default is ['cluster','purity','completeness']). 
Returns ------- @@ -185,74 +291,133 @@ def run_dbscan( ------- ValueError sets `usecols` and `dropcols` may not share elements - ValueError - Method is not an available choice: choices=['DBSCAN','HDBSCAN'] + """ + for col in dropcols: + if col in df.columns: + df.drop(columns=col, inplace=True) n_samples = df.shape[0] if n_samples == 1: clusters = pd.Series([pd.NA], index=df.index, name="cluster") return pd.merge(df, clusters, how="left", left_index=True, right_index=True) - for col in dropcols: - if col in df.columns: - df.drop(columns=col, inplace=True) cols = ["x", "y"] - cols.append("coverage") if "coverage" in df.columns else cols.append("z") + if "coverage" in df.columns: + cols.append("coverage") + if np.any(df.isnull()): + raise BinningError( + f"df is missing {df.isnull().sum().sum()} kmer/coverage annotations" + ) X = df.loc[:, cols].to_numpy() - if method == "DBSCAN": - clustering = DBSCAN(eps=eps, min_samples=1, n_jobs=-1).fit(X) - elif method == "HDBSCAN": - clustering = HDBSCAN( - cluster_selection_epsilon=eps, - min_cluster_size=2, - min_samples=1, - allow_single_cluster=True, - gen_min_span_tree=True, - ).fit(X) - else: - raise ValueError(f"Method: {method} not a choice. choose b/w DBSCAN & HDBSCAN") - clusters = pd.Series(clustering.labels_, index=df.index, name="cluster") + clusterer = HDBSCAN( + min_cluster_size=min_cluster_size, + min_samples=min_samples, + cluster_selection_method="leaf", + allow_single_cluster=True, + memory=cache_dir, + ).fit(X) + clusters = pd.Series(clusterer.labels_, index=df.index, name="cluster") return pd.merge(df, clusters, how="left", left_index=True, right_index=True) -def add_metrics(df, markers_df, domain="bacteria"): - """Adds the completeness and purity metrics to each respective contig in df. +def recursive_hdbscan( + table, markers_df, domain, completeness_cutoff, purity_cutoff, verbose=False, +): + """Recursively run HDBSCAN starting with defaults and iterating the min_samples + and min_cluster_size until only 1 cluster is recovered. Parameters ---------- - df : pd.DataFrame - Description of parameter `df`. + table : pd.DataFrame + Contigs with embedded k-mer frequencies as ['x','y','z'] columns and + optionally 'coverage' column markers_df : pd.DataFrame wide format, i.e. index=contig cols=[marker,marker,...] domain : str Kingdom to determine metrics (the default is 'bacteria'). choices=['bacteria','archaea'] + completeness_cutoff : float + `completeness_cutoff` threshold to retain cluster (the default is 20.0). + purity_cutoff : float + `purity_cutoff` threshold to retain cluster (the default is 90.0). + verbose : bool + log stats for each recursive_dbscan clustering iteration. Returns ------- - pd.DataFrame - DataFrame with added columns - 'completeness' and 'purity' + 2-tuple + (pd.DataFrame(), pd.DataFrame()) + DataFrames consisting of contigs that passed/failed clustering + cutoffs, respectively. 
+ + DataFrame: + index = contig + columns = ['x,'y','z','coverage','cluster','purity','completeness'] """ - marker_sets = {"bacteria": 139.0, "archaea": 162.0} - expected_number = marker_sets.get(domain.lower(), 139) - clusters = dict(list(df.groupby("cluster"))) - metrics = {"purity": {}, "completeness": {}} - for cluster, dff in clusters.items(): - contigs = dff.index.tolist() - summed_markers = markers_df[markers_df.index.isin(contigs)].sum() - is_present = summed_markers >= 1 - is_single_copy = summed_markers == 1 - nunique_markers = summed_markers[is_present].index.nunique() - num_single_copy_markers = summed_markers[is_single_copy].index.nunique() - completeness = nunique_markers / expected_number * 100 - # Protect from divide by zero - if nunique_markers == 0: - purity = pd.NA + max_min_cluster_size = 10000 + max_min_samples = 10 + min_cluster_size = 2 + min_samples = 1 + n_clusters = float("inf") + best_median = float("-inf") + best_df = pd.DataFrame() + cache_dir = tempfile.mkdtemp() + while n_clusters > 1: + binned_df = run_hdbscan( + table, + min_cluster_size=min_cluster_size, + min_samples=min_samples, + cache_dir=cache_dir, + ) + df, metrics_df = add_metrics(df=binned_df, markers_df=markers_df, domain=domain) + + completeness_filter = metrics_df["completeness"] >= completeness_cutoff + purity_filter = metrics_df["purity"] >= purity_cutoff + median_completeness = metrics_df[completeness_filter & purity_filter][ + "completeness" + ].median() + if median_completeness >= best_median: + best_median = median_completeness + best_df = df + + n_clusters = df["cluster"].nunique() + + if verbose: + logger.debug( + f"(min_samples, min_cluster_size): ({min_samples}, {min_cluster_size}) clusters: {n_clusters}" + f" median completeness (current, best): ({median_completeness:4.2f}, {best_median:4.2f})" + ) + + if min_cluster_size >= max_min_cluster_size: + shutil.rmtree(cache_dir) + cache_dir = tempfile.mkdtemp() + min_samples += 1 + min_cluster_size = 2 else: - purity = num_single_copy_markers / nunique_markers * 100 - metrics["completeness"].update({cluster: completeness}) - metrics["purity"].update({cluster: purity}) - metrics_df = pd.DataFrame(metrics, index=clusters.keys()) - return pd.merge(df, metrics_df, left_on="cluster", right_index=True) + min_cluster_size += 10 + + if metrics_df[completeness_filter & purity_filter].empty: + min_cluster_size += 100 + + if min_samples >= max_min_samples: + max_min_cluster_size *= 2 + + shutil.rmtree(cache_dir) + + if best_df.empty: + if verbose: + logger.debug("No complete or pure clusters found") + return pd.DataFrame(), table + + completeness_filter = best_df["completeness"] >= completeness_cutoff + purity_filter = best_df["purity"] >= purity_cutoff + complete_and_pure_df = best_df.loc[completeness_filter & purity_filter] + unclustered_df = best_df.loc[~best_df.index.isin(complete_and_pure_df.index)] + if verbose: + logger.debug(f"Best completeness median: {best_median:4.2f}") + logger.debug( + f"clustered: {len(complete_and_pure_df)} unclustered: {len(unclustered_df)}" + ) + return complete_and_pure_df, unclustered_df def get_clusters( @@ -261,7 +426,7 @@ def get_clusters( domain="bacteria", completeness=20.0, purity=90.0, - method="DBSCAN", + method="dbscan", verbose=False, ): """Find best clusters retained after applying `completeness` and `purity` filters. @@ -281,8 +446,8 @@ def get_clusters( purity : float `purity` threshold to retain cluster (the default is 90.). method : str - Description of parameter `method` (the default is 'DBSCAN'). 
- choices = ['DBSCAN','HDBSCAN'] + Description of parameter `method` (the default is 'dbscan'). + choices = ['dbscan','hdbscan'] verbose : bool log stats for each recursive_dbscan clustering iteration @@ -293,15 +458,16 @@ def get_clusters( """ num_clusters = 0 clusters = [] + recursive_clusterers = {"dbscan": recursive_dbscan, "hdbscan": recursive_hdbscan} + if method not in recursive_clusterers: + raise ValueError(f"Method: {method} not a choice. choose b/w dbscan & hdbscan") + clusterer = recursive_clusterers[method] + + # Continue while unclustered are remaining + # break when either clustered_df or unclustered_df is empty while True: - clustered_df, unclustered_df = recursive_dbscan( - master_df, - markers_df, - domain, - completeness, - purity, - method=method, - verbose=verbose, + clustered_df, unclustered_df = clusterer( + master_df, markers_df, domain, completeness, purity, verbose=verbose, ) # No contigs can be clustered, label as unclustered and add the final df # of (unclustered) contigs @@ -314,7 +480,10 @@ def get_clusters( c: f"bin_{1+i+num_clusters:04d}" for i, c in enumerate(clustered_df.cluster.unique()) } - rename_cluster = lambda c: translation[c] + + def rename_cluster(c): + return translation[c] + clustered_df.cluster = clustered_df.cluster.map(rename_cluster) # All contigs have now been clustered, add the final df of (clustered) contigs @@ -336,8 +505,9 @@ def binning( completeness=20.0, purity=90.0, taxonomy=True, - method="DBSCAN", - reverse=True, + starting_rank="superkingdom", + method="dbscan", + reverse_ranks=False, verbose=False, ): """Perform clustering of contigs by provided `method` and use metrics to @@ -353,23 +523,26 @@ def binning( i.e. [taxid,superkingdom,phylum,class,order,family,genus,species] markers : pd.DataFrame wide format, i.e. index=contig cols=[marker,marker,...] - domain : str + domain : str, optional Kingdom to determine metrics (the default is 'bacteria'). choices=['bacteria','archaea'] - completeness : float + completeness : float, optional Description of parameter `completeness` (the default is 20.). - purity : float + purity : float, optional Description of parameter `purity` (the default is 90.). - taxonomy : bool + taxonomy : bool, optional Split canonical ranks and subset based on rank then attempt to find clusters (the default is True). taxonomic_levels = [superkingdom,phylum,class,order,family,genus,species] - method : str - Clustering `method` (the default is 'DBSCAN'). - choices = ['DBSCAN','HDBSCAN'] - reverse : bool - True - [superkingdom,phylum,class,order,family,genus,species] - False - [species,genus,family,order,class,phylum,superkingdom] - verbose : bool + starting_rank : str, optional + Starting canonical rank at which to begin subsetting taxonomy (the default is superkingdom). + Choices are superkingdom, phylum, class, order, family, genus, species. + method : str, optional + Clustering `method` (the default is 'dbscan'). + choices = ['dbscan','hdbscan'] + reverse_ranks : bool, optional + False - [superkingdom,phylum,class,order,family,genus,species] (Default) + True - [species,genus,family,order,class,phylum,superkingdom] + verbose : bool, optional log stats for each recursive_dbscan clustering iteration Returns @@ -379,44 +552,53 @@ def binning( Raises ------- - RecursiveDBSCANError + BinningError No marker information is availble for contigs to be binned. """ # First check needs to ensure we have markers available to check binning quality... 
if master.loc[master.index.isin(markers.index)].empty: - err = "No markers for contigs in table. Unable to assess binning quality" - raise RecursiveDBSCANError(err) + raise BinningError( + "No markers for contigs in table. Unable to assess binning quality" + ) + logger.info(f"Using {method} clustering method") if not taxonomy: return get_clusters( - master, markers, domain, completeness, purity, method, verbose + master_df=master, + markers_df=markers, + domain=domain, + completeness=completeness, + purity=purity, + method=method, + verbose=verbose, ) # Use taxonomy method - if reverse: - # superkingdom, phylum, class, order, family, genus, species - ranks = [rank for rank in reversed(NCBI.CANONICAL_RANKS)] - else: + if reverse_ranks: # species, genus, family, order, class, phylum, superkingdom ranks = [rank for rank in NCBI.CANONICAL_RANKS] + else: + # superkingdom, phylum, class, order, family, genus, species + ranks = [rank for rank in reversed(NCBI.CANONICAL_RANKS)] ranks.remove("root") + starting_rank_index = ranks.index(starting_rank) + ranks = ranks[starting_rank_index:] + logger.debug(f"Using ranks: {', '.join(ranks)}") clustered_contigs = set() num_clusters = 0 clusters = [] for rank in ranks: # TODO: We should account for novel taxa here instead of removing 'unclassified' unclassified_filter = master[rank] != "unclassified" - n_contigs_in_taxa = ( - master.loc[unclassified_filter].groupby(rank)[rank].count().sum() - ) - n_taxa = ( - master.loc[unclassified_filter].groupby(rank)[rank].count().index.nunique() - ) + master_grouped_by_rank = master.loc[unclassified_filter].groupby(rank) + taxa_counts = master_grouped_by_rank[rank].count() + n_contigs_in_taxa = taxa_counts.sum() + n_taxa = taxa_counts.index.nunique() logger.info( f"Examining {rank}: {n_taxa:,} unique taxa ({n_contigs_in_taxa:,} contigs)" ) # Group contigs by rank and find best clusters within subset - for rank_name_txt, dff in master.loc[unclassified_filter].groupby(rank): + for rank_name_txt, dff in master_grouped_by_rank: if dff.empty: continue # Only cluster contigs that have not already been assigned a bin. @@ -428,14 +610,20 @@ def binning( if clustered_contigs: rank_df = rank_df.loc[~rank_df.index.isin(clustered_contigs)] # After all of the filters, are there multiple contigs to cluster? - if len(rank_df) <= 1: + if rank_df.empty: continue # Find best clusters logger.debug( f"Examining taxonomy: {rank} : {rank_name_txt} : {rank_df.shape}" ) clusters_df = get_clusters( - rank_df, markers, domain, completeness, purity, method + master_df=rank_df, + markers_df=markers, + domain=domain, + completeness=completeness, + purity=purity, + method=method, + verbose=verbose, ) # Store clustered contigs is_clustered = clusters_df["cluster"].notnull() @@ -447,7 +635,10 @@ def binning( c: f"bin_{1+i+num_clusters:04d}" for i, c in enumerate(clustered.cluster.unique()) } - rename_cluster = lambda c: translation[c] + + def rename_cluster(c): + return translation[c] + clustered.cluster = clustered.cluster.map(rename_cluster) num_clusters += clustered.cluster.nunique() clusters.append(clustered) @@ -467,7 +658,10 @@ def main(): level=logger.DEBUG, ) parser = argparse.ArgumentParser( - description="Perform decomposition/embedding/clustering via PCA/[TSNE|UMAP]/DBSCAN." 
+ description="Perform marker gene guided binning of " + "metagenome contigs using annotations (when available) of sequence " + "composition, coverage and homology.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("kmers", help="") parser.add_argument("coverage", help="") @@ -477,20 +671,41 @@ def main(): parser.add_argument( "--embedding-method", help="Embedding method to use", - choices=["TSNE", "UMAP"], - default="TSNE", + choices=["bhsne", "sksne", "umap"], + default="bhsne", ) parser.add_argument( "--clustering-method", help="Clustering method to use", - choices=["DBSCAN", "HDBSCAN"], - default="DBSCAN", + choices=["dbscan", "hdbscan"], + default="dbscan", ) parser.add_argument( "--completeness", help="", default=20.0, type=float ) parser.add_argument("--purity", help="", default=90.0, type=float) parser.add_argument("--taxonomy", help="") + parser.add_argument( + "--starting-rank", + help="Canonical rank at which to begin subsetting taxonomy", + default="superkingdom", + choices=[ + "superkingdom", + "phylum", + "class", + "order", + "family", + "genus", + "species", + ], + ) + parser.add_argument( + "--reverse-ranks", + action="store_true", + default=False, + help="Reverse order at which to split taxonomy by canonical-rank." + " When --reverse-ranks given, contigs will be split in order of species, genus, family, order, class, phylum, superkingdom.", + ) parser.add_argument( "--domain", help="Kingdom to consider (archaea|bacteria)", @@ -508,7 +723,7 @@ def main(): cov_df = pd.read_csv(args.coverage, sep="\t", index_col="contig") master_df = pd.merge( - kmers_df, cov_df[["coverage"]], how="left", left_index=True, right_index=True + kmers_df, cov_df[["coverage"]], how="left", left_index=True, right_index=True, ) markers_df = Markers.load(args.markers) @@ -532,6 +747,8 @@ def main(): master=master_df, markers=markers_df, taxonomy=taxa_present, + starting_rank=args.starting_rank, + reverse_ranks=args.reverse_ranks, domain=args.domain, completeness=args.completeness, purity=args.purity, diff --git a/autometa/common/exceptions.py b/autometa/common/exceptions.py index a6fa86a8f..0e53579a3 100644 --- a/autometa/common/exceptions.py +++ b/autometa/common/exceptions.py @@ -67,8 +67,8 @@ def __str__(self): return self.value -class RecursiveDBSCANError(AutometaException): - """RecursiveDBSCANError exception class.""" +class BinningError(AutometaException): + """BinningError exception class.""" def __init__(self, value): self.value = value diff --git a/autometa/common/external/diamond.py b/autometa/common/external/diamond.py index 8330cc38d..9de1df1f6 100644 --- a/autometa/common/external/diamond.py +++ b/autometa/common/external/diamond.py @@ -29,6 +29,8 @@ import os import subprocess +import multiprocessing as mp + from itertools import chain from tqdm import tqdm @@ -39,38 +41,19 @@ class DiamondResult: - """DiamondResult class + """ + DiamondResult class to instantiate a DiamondResult object for each qseqid. + Allows for better handling of subject sequence IDs that hit to a qseqid. - Parameters - ---------- - qseqid : str - query sequence ID - sseqid : str - subject sequence ID - pident : float - Percentage of identical matches. - length : int - Alignment length. - mismatch : int - Number of mismatches. - gapopen : int - Number of gap openings. - qstart : int - Start of alignment in query. - qend : int - End of alignment in query. - sstart : int - Start of alignment in subject. - send : int - End of alignment in subject sequence. 
- evalue : float - Expect value. - bitscore : float - Bitscore. + Includes methods used to modify the DiamondResult object (add or remove sseqid from sseqids dictionary), + check if two DiamondResult objects have the same qseqid and return a user friendly and unambiguous output + from str() and repr() respectively. Also used to return the sseqid with the highest bitscore amongst all + the subject sequences that hit a query. These methods come in handy when retrieving diamond results from output table. Attributes ---------- sseqids : dict + All the subject sequences that hit to the query sequence {sseqid:parameters, sseqid:parameters, ...} qseqid: str result query sequence ID @@ -92,6 +75,36 @@ def __init__( evalue, bitscore, ): + """ + Instantiates the DiamondResult class + + Parameters + ---------- + qseqid : str + query sequence ID + sseqid : str + subject sequence ID + pident : float + Percentage of identical matches. + length : int + Alignment length. + mismatch : int + Number of mismatches. + gapopen : int + Number of gap openings. + qstart : int + Start of alignment in query. + qend : int + End of alignment in query. + sstart : int + Start of alignment in subject. + send : int + End of alignment in subject sequence. + evalue : float + Expect value. + bitscore : float + Bitscore. + """ self.qseqid = qseqid self.sseqids = { sseqid: { @@ -108,29 +121,110 @@ def __init__( } } - # def __repr__(self): - # return str(self) + def __repr__(self): + """ + Operator overloading to return the representation of the class object + + Returns + ------- + str + Class name followed by query sequence ID + """ + return f"Class: {self.__class__.__name__}, Query seqID: {self.qseqid}" def __str__(self): + """ + Operator overloading to return the string representation of the class objects + + Returns + ------- + str + String representation of query sequence ID, followed by total number of hits and finally the highest + bit score of all the hits + """ return f"{self.qseqid}; {len(self.sseqids)} sseqids; top hit by bitscore: {self.get_top_hit()}" def __eq__(self, other_hit): + """ + Operator overloading to compare two objects of the DiamondResult class + + Parameters + ---------- + other_hit : DiamondResult object + Other DiamondResult object to compare with + + Returns + ------- + Boolean + True if qseqid corresponding to both DiamondResult objects are equal, else False + """ if self.qseqid == other_hit.qseqid: return True else: return False def __add__(self, other_hit): - assert self == other_hit, f"qseqids do not match! {self} & {other_hit}" + """ + Operator overloading to update (add) the sseqids dictionary with the other sseqid hit dictionary. + The addition will only be successful if both the DiamondResult objects have the same qseqid + + Parameters + ---------- + other_hit : DiamondResult object + Other DiamondResult object to add + + Returns + ------- + DiamondResult object + DiamondResult object whose sseqids dict has been updated (added) with another sseqid + + Raises + ------ + AssertionError + Query sequences are not equal + """ + assert ( + self == other_hit + ), f"qseqids do not match! {self.qseqid} & {other_hit.qseqid}" self.sseqids.update(other_hit.sseqids) return self def __sub__(self, other_hit): - assert self == other_hit, f"qseqids do not match! {self} & {other_hit}" - self.sseqids.pop(other_hit.sseqid) + """ + Operator overloading to remove (subtract) the other sseqid hits dictionary from the sseqids dictionary. 
+ The subtraction will only be successful if both the DiamondResult objects have the same qseqid + + Parameters + ---------- + other_hit : Dict + Other DiamondResult object to subtract + + Returns + ------- + DiamondResult object + DiamondResult object where a sseqid has been removed (subtracted) from the sseqids dict + + Raises + ------ + AssertionError + Query sequences are not equal + """ + assert ( + self == other_hit + ), f"qseqids do not match! {self.qseqid} & {other_hit.qseqid}" + for sseqid in other_hit.sseqids: + try: + self.sseqids.pop(sseqid) + except KeyError: + raise KeyError( + f"Given sseqid: {sseqid} is absent from the corresponding DiamondResult" + ) return self def get_top_hit(self): + """ + Returns the subject sequence ID with the highest bitscore amongst all the subject sequences that hit a query + """ top_bitscore = float("-Inf") top_hit = None for sseqid, attrs in self.sseqids.items(): @@ -140,13 +234,36 @@ def get_top_hit(self): return top_hit -def makedatabase(fasta, database, nproc=1): - cmd = f"diamond makedb --in {fasta} --db {database} -p {nproc}" - logger.debug(f"{cmd}") - with open(os.devnull, "w") as stdout, open(os.devnull, "w") as stderr: - retcode = subprocess.call(cmd, stdout=stdout, stderr=stderr, shell=True) - if retcode: - raise OSError(f"DiamondFailed:\nArgs:{proc.args}\nReturnCode:{proc.returncode}") +def makedatabase(fasta, database, cpus=mp.cpu_count()): + """ + Creates a database against which the query sequence would be blasted + + Parameters + ---------- + fasta : str + Path to fasta file whose database needs to be made + e.g. '' + database : str + Path to the output diamond formatted database file + e.g. '' + cpus : int, optional + Number of processors to be used. By default uses all the processors of the system + + Returns + ------- + str + Path to diamond formatted database + + Raises + ------ + subprocess.CalledProcessError + Failed to create diamond formatted database + """ + cmd = ["diamond", "makedb", "--in", fasta, "--db", database, "-p", str(cpus)] + logger.debug(" ".join(cmd)) + subprocess.run( + cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True + ) return database @@ -157,55 +274,57 @@ def blast( blast_type="blastp", evalue=float("1e-5"), maxtargetseqs=200, - cpus=0, - tmpdir=os.curdir, + cpus=mp.cpu_count(), + tmpdir=None, force=False, verbose=False, ): - """Performs diamond blastp search using fasta against diamond formatted database + """ + Performs diamond blastp search using query sequence against diamond formatted database Parameters ---------- fasta : str - . May be amino acid or nucleotide sequences + Path to fasta file having the query sequences. Should be amino acid sequences in case of BLASTP + and nucleotide sequences in case of BLASTX database : str - . + Path to diamond formatted database outfpath : str - . - blast_type : str - blastp if fasta consists of amino acids or blastx if nucleotides - evalue : float - cutoff e-value to count hit as significant (the default is float('1e-5')). - maxtargetseqs : int - max number of target sequences to retrieve per query by diamond (the default is 200). - cpus : int - number of cpus to use (the default is 1). - tmpdir : type - (the default is os.curdir). - force : boolean - overwrite existing diamond results `force` (the default is False). - verbose : boolean - log progress to terminal `verbose` (the default is False). 
+ Path to output file + blast_type : str, optional + blastp to align protein query sequences against a protein reference database, + blastx to align translated DNA query sequences against a protein reference database, by default 'blastp' + evalue : float, optional + cutoff e-value to count hit as significant, by default float('1e-5') + maxtargetseqs : int, optional + max number of target sequences to retrieve per query by diamond, by default 200 + cpus : int, optional + Number of processors to be used, by default uses all the processors of the system + tmpdir : str, optional + Path to temporary directory. By default, same as the output directory + force : bool, optional + overwrite existing diamond results, by default False + verbose : bool, optional + log progress to terminal, by default False Returns ------- str - `outfpath` + Path to BLAST results Raises - ------- + ------ FileNotFoundError `fasta` file does not exist ValueError provided `blast_type` is not 'blastp' or 'blastx' - OSError - Diamond execution failed + subprocess.CalledProcessError + Failed to run blast """ if not os.path.exists(fasta): raise FileNotFoundError(fasta) if os.path.exists(outfpath) and not force: - empty = not os.stat(outfpath).st_size - if not empty: + if os.path.getsize(outfpath): if verbose: logger.warning(f"FileExistsError: {outfpath}. To overwrite use --force") return outfpath @@ -213,10 +332,10 @@ def blast( if blast_type not in ["blastp", "blastx"]: raise ValueError(f"blast_type must be blastp or blastx. Given: {blast_type}") if verbose: - logger.debug(f"Diamond{blast_type.title()} {fasta} against {database}") + logger.debug(f"diamond {blast_type} {fasta} against {database}") cmd = [ "diamond", - "blastp", + blast_type, "--query", fasta, "--db", @@ -231,28 +350,33 @@ def blast( "6", "--out", outfpath, - "--tmpdir", - tmpdir, ] + if tmpdir: + cmd.extend(["--tmpdir", tmpdir]) + # this is done as when cmd is a list each element should be a string cmd = [str(c) for c in cmd] if verbose: logger.debug(f'RunningDiamond: {" ".join(cmd)}') - with open(os.devnull, "w") as stdout, open(os.devnull, "w") as stderr: - proc = subprocess.run(cmd, stdout=stdout, stderr=stderr) - if proc.returncode: - raise OSError(f"DiamondFailed:\nArgs:{proc.args}\nReturnCode:{proc.returncode}") + subprocess.run( + cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True + ) return outfpath -def parse(results, top_pct=0.9, verbose=False): - """Retrieve diamond results from output table +def parse(results, bitscore_filter=0.9, verbose=False): + """ + Retrieve diamond results from output table Parameters ---------- results : str - - top_pct : 0 < float <= 1 - bitscore filter applied to each qseqid (the default is 0.9). + Path to BLASTP output file in outfmt9 + bitscore_filter : 0 < float <= 1, optional + Bitscore filter applied to each sseqid, by default 0.9 + Used to determine whether the bitscore is above a threshold value. + For example, if it is 0.9 then only bitscores >= 0.9 * the top bitscore are accepted + verbose : bool, optional + log progress to terminal, by default False Returns ------- @@ -264,7 +388,7 @@ def parse(results, top_pct=0.9, verbose=False): FileNotFoundError diamond results table does not exist ValueError - top_pct value is not a float or not in range of 0 to 1 + bitscore_filter value is not a float or not in range of 0 to 1 """ disable = False if verbose else True # boolean toggle --> keeping above vs. below because I think this is more readable. 
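For reference, the `bitscore_filter` documented above keeps a subject hit only when its bitscore is at least `bitscore_filter` times the best bitscore seen for that query. A minimal standalone sketch of that rule, separate from the patch itself (the accessions and scores below are made up for illustration):

```python
# Illustration of the bitscore_filter rule used in parse(); IDs and scores are hypothetical.
bitscore_filter = 0.9  # parse() default

# bitscores reported for one query's subject hits
hit_bitscores = {"WP_000001.1": 500.0, "WP_000002.1": 462.0, "WP_000003.1": 380.0}

top_bitscore = max(hit_bitscores.values())
retained = {
    sseqid: bitscore
    for sseqid, bitscore in hit_bitscores.items()
    if bitscore >= bitscore_filter * top_bitscore
}
# Only hits scoring >= 0.9 * 500.0 = 450.0 are kept:
# {'WP_000001.1': 500.0, 'WP_000002.1': 462.0}
print(retained)
```

In `parse()` itself the top hit for each query is always recorded first, and later hits are only added when they clear this threshold.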
@@ -274,14 +398,14 @@ def parse(results, top_pct=0.9, verbose=False): if not os.path.exists(results): raise FileNotFoundError(results) try: - float(top_pct) - except ValueError as err: + float(bitscore_filter) + except ValueError: raise ValueError( - f"top_pct must be a float! Input: {top_pct} Type: {type(top_pct)}" + f"bitscore_filter must be a float! Input: {bitscore_filter} Type: {type(bitscore_filter)}" ) - in_range = 0.0 < top_pct <= 1.0 + in_range = 0.0 < bitscore_filter <= 1.0 if not in_range: - raise ValueError(f"top_pct not in range(0,1)! Input: {top_pct}") + raise ValueError(f"bitscore_filter not in range(0,1)! Input: {bitscore_filter}") hits = {} temp = set() n_lines = file_length(results) if verbose else None @@ -321,13 +445,17 @@ def parse(results, top_pct=0.9, verbose=False): topbitscore = bitscore temp = set([hit.qseqid]) continue - if bitscore >= top_pct * topbitscore: + if bitscore >= bitscore_filter * topbitscore: hits[hit.qseqid] += hit return hits def add_taxids(hits, database, verbose=True): - """Translates accessions to taxid translations from prot.accession2taxid.gz + """ + Translates accessions to taxids from prot.accession2taxid.gz. If an accession number is no + longer available in prot.accesssion2taxid.gz (either due to being suppressed, deprecated or + removed by NCBI), then None is returned as the taxid for the corresponsing sseqid. + # TODO: Should maybe write a wrapper for this to run on all of the NCBI databases listed below... Maybe this will help account for instances where the @@ -347,7 +475,9 @@ def add_taxids(hits, database, verbose=True): hits : dict {qseqid: DiamondResult, ...} database : str - + Path to prot.accession2taxid.gz database + verbose : bool, optional + log progress to terminal, by default False Returns ------- @@ -358,13 +488,6 @@ def add_taxids(hits, database, verbose=True): ------- FileNotFoundError prot.accession2taxid.gz database is required for translation taxid - DatabasesOutOfDateError - prot.accession2taxid.gz database and nr.dmnd are out of sync resulting - in accessions that are no longer available (either due to being - suppressed, deprecated or removed by NCBI). This must be resolved by - updating both nr and prot.accession2taxid.gz and re-running diamond on - the new nr.dmnd database. Alternatively, can try to find the exceptions in merged.dmp - # TODO: Replace file_length func for database file. (in this case 808,717,857 lines takes ~15 minutes simply to read each line...) @@ -379,18 +502,17 @@ def add_taxids(hits, database, verbose=True): accessions = set( chain.from_iterable([hit.sseqids.keys() for qseqid, hit in hits.items()]) ) - fh = gzip.open(database) if database.endswith(".gz") else open(database) - __ = fh.readline() + # "rt" open the database in text mode instead of binary. Now it can be handled like a text file + fh = gzip.open(database, "rt") if database.endswith(".gz") else open(database) + __ = fh.readline() # remove the first line as it just gives the description if verbose: logger.debug( f"Searching for {len(accessions):,} accessions in {os.path.basename(database)}. This may take a while..." 
) - is_gzipped = True if database.endswith(".gz") else False n_lines = file_length(database) if verbose else None desc = f"Parsing {os.path.basename(database)}" acc2taxids = {} for line in tqdm(fh, disable=disable, desc=desc, total=n_lines, leave=False): - line = line.decode() if is_gzipped else line acc_num, acc_ver, taxid, _ = line.split("\t") taxid = int(taxid) if acc_num in accessions: @@ -405,48 +527,23 @@ def add_taxids(hits, database, verbose=True): ): for sseqid in hit.sseqids: taxid = acc2taxids.get(sseqid) - # if taxid is None: - # raise DatabasesOutOfDateError(f'{sseqid} deprecated/suppressed/removed') hit.sseqids[sseqid].update({"taxid": taxid}) return hits -def main(args): - result = blast( - fasta=args.fasta, - database=args.database, - outfpath=args.outfile, - blast_type=args.blast_type, - evalue=args.evalue, - maxtargetseqs=args.maxtargetseqs, - cpus=args.cpus, - tmpdir=args.tmpdir, - force=args.force, - verbose=args.verbose, - ) - hits = parse(results=result, top_pct=args.top_pct, verbose=args.verbose) - hits = add_taxids(hits=hits, database=args.acc2taxids, verbose=args.verbose) - fname, __ = os.path.splitext(os.path.basename(args.outfile)) - dirpath = os.path.dirname(os.path.realpath(args.outfile)) - hits_fname = ".".join([fname, "pkl.gz"]) - hits_fpath = os.path.join(dirpath, hits_fname) - pickled_fpath = make_pickle(obj=hits, outfpath=hits_fpath) - logger.debug(f"{len(hits):,} diamond hits serialized to {pickled_fpath}") - - -if __name__ == "__main__": +def main(): import argparse - import os parser = argparse.ArgumentParser( description=""" Retrieves blastp hits with provided input assembly - """ + """, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("fasta", help="") - parser.add_argument("database", help="") - parser.add_argument("acc2taxids", help="") - parser.add_argument("outfile", help="") + parser.add_argument("fasta", help="Path to fasta file having the query sequences") + parser.add_argument("database", help="Path to diamond formatted database") + parser.add_argument("acc2taxids", help="Path to prot.accession2taxid.gz database") + parser.add_argument("outfile", help="Path to output file") parser.add_argument( "blast_type", help="[blastp]: A.A -> A.A. [blastx]: Nucl. 
-> A.A.", @@ -462,14 +559,46 @@ def main(args): default=200, type=int, ) - parser.add_argument("--cpus", help="num cpus to use", default=0, type=int) - parser.add_argument("--tmpdir", help="", default=os.curdir) parser.add_argument( - "--top-pct", help="top percentage of hits to retrieve", default=0.9 + "--cpus", help="number of processors to use", default=mp.cpu_count(), type=int + ) + parser.add_argument( + "--tmpdir", + help="Path to directory which will be used for temporary storage by diamond", + ) + parser.add_argument( + "--bitscore-filter", + help="hits with bitscore greater than bitscr-filter*top bitscore will be kept", + default=0.9, ) parser.add_argument( "--force", help="force overwrite of diamond output table", action="store_true" ) parser.add_argument("--verbose", help="add verbosity", action="store_true") args = parser.parse_args() - main(args) + result = blast( + fasta=args.fasta, + database=args.database, + outfpath=args.outfile, + blast_type=args.blast_type, + evalue=args.evalue, + maxtargetseqs=args.maxtargetseqs, + cpus=args.cpus, + tmpdir=args.tmpdir, + force=args.force, + verbose=args.verbose, + ) + hits = parse( + results=result, bitscore_filter=args.bitscore_filter, verbose=args.verbose + ) + hits = add_taxids(hits=hits, database=args.acc2taxids, verbose=args.verbose) + fname, __ = os.path.splitext(os.path.basename(args.outfile)) + dirpath = os.path.dirname(os.path.realpath(args.outfile)) + hits_fname = ".".join([fname, "pkl.gz"]) + hits_fpath = os.path.join(dirpath, hits_fname) + pickled_fpath = make_pickle(obj=hits, outfpath=hits_fpath) + logger.debug(f"{len(hits):,} diamond hits serialized to {pickled_fpath}") + + +if __name__ == "__main__": + main() diff --git a/autometa/common/utilities.py b/autometa/common/utilities.py index e7951b19f..dd0d7f72d 100644 --- a/autometa/common/utilities.py +++ b/autometa/common/utilities.py @@ -29,9 +29,12 @@ import logging import os import pickle +import sys import tarfile import time +import numpy as np + from functools import wraps @@ -101,8 +104,8 @@ def make_pickle(obj, outfpath): return outfpath -def gunzip(infpath, outfpath): - """Decompress gzipped `infpath` to `outfpath`. +def gunzip(infpath, outfpath, delete_original=False, block_size=65536): + """Decompress gzipped `infpath` to `outfpath` and write checksum of `outfpath` upon successful decompression. Parameters ---------- @@ -110,6 +113,10 @@ def gunzip(infpath, outfpath): outfpath : str + delete_original : bool + Will delete the original file after successfully decompressing `infpath` (Default is False). + block_size : int + Amount of `infpath` to read in to memory before writing to `outfpath` (Default is 65536 bytes). 
Returns ------- @@ -125,15 +132,21 @@ def gunzip(infpath, outfpath): logger.debug( f"gunzipping {os.path.basename(infpath)} to {os.path.basename(outfpath)}" ) - if os.path.exists(outfpath) and os.stat(outfpath).st_size > 0: + if os.path.exists(outfpath) and os.path.getsize(outfpath) > 0: raise FileExistsError(outfpath) lines = "" - with gzip.open(infpath) as fh: - for line in fh: - lines += line.decode() - with open(outfpath, "w") as out: + with gzip.open(infpath, "rt") as fh, open(outfpath, "w") as out: + for i, line in enumerate(fh): + lines += line + if sys.getsizeof(lines) >= block_size: + out.write(lines) + lines = "" out.write(lines) logger.debug(f"gunzipped {infpath} to {outfpath}") + write_checksum(outfpath, f"{outfpath}.md5") + if delete_original: + os.remove(infpath) + logger.debug(f"removed original file: {infpath}") return outfpath @@ -149,7 +162,7 @@ def untar(tarchive, outdir, member=None): outdir : str - member : str + member : str, optional member file to extract. Returns @@ -165,6 +178,7 @@ def untar(tarchive, outdir, member=None): `tarchive` is not a tar archive KeyError `member` was not found in `tarchive` + """ if not member and not outdir: raise ValueError( @@ -172,7 +186,7 @@ def untar(tarchive, outdir, member=None): ) logger.debug(f"decompressing tarchive {tarchive} to {outdir}") outfpath = os.path.join(outdir, member) if member else None - if member and os.path.exists(outfpath) and os.stat(outfpath).st_size > 0: + if member and os.path.exists(outfpath) and os.path.getsize(outfpath) > 0: raise FileExistsError(outfpath) if not tarfile.is_tarfile(tarchive): raise ValueError(f"{tarchive} is not a tar archive") @@ -219,6 +233,7 @@ def tarchive_results(outfpath, src_dirpath): ------- FileExistsError `outfpath` already exists + """ logger.debug(f"tar archiving {src_dirpath} to {outfpath}") if os.path.exists(outfpath): @@ -229,16 +244,17 @@ def tarchive_results(outfpath, src_dirpath): return outfpath -def file_length(fpath): +def file_length(fpath, approximate=False): """Retrieve the number of lines in `fpath` - See: - https://stackoverflow.com/questions/845058/how-to-get-line-count-of-a-large-file-cheaply-in-python + See: https://stackoverflow.com/q/845058/13118765 Parameters ---------- fpath : str Description of parameter `fpath`. + approximate: bool + If True, will approximate the length of the file from the file size. Returns ------- @@ -253,18 +269,28 @@ def file_length(fpath): """ if not os.path.exists(fpath): raise FileNotFoundError(fpath) - if fpath.endswith(".gz"): - fh = gzip.open(fpath, "rb") - else: - fh = open(fpath, "rb") + + fh = gzip.open(fpath, "rt") if fpath.endswith(".gz") else open(fpath, "rb") + if approximate: + lines = [] + n_sample_lines = 100000 + for i, l in enumerate(fh): + if i > n_sample_lines: + break + lines.append(sys.getsizeof(l)) + fh.close() + avg_size_per_line = np.average(lines) + total_size = os.path.getsize(fpath) + return int(np.ceil(total_size / avg_size_per_line)) + for i, l in enumerate(fh): pass fh.close() return i + 1 -def get_checksum(fpath): - """Retrieve sha256 checksums from provided `args`. +def calc_checksum(fpath): + """Retrieve md5 checksum from provided `fpath`. See: https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file @@ -277,7 +303,8 @@ def get_checksum(fpath): Returns ------- str - hexdigest of `fpath` using sha256 + space-delimited hexdigest of `fpath` using md5sum and basename of `fpath`. + e.g. 
'hash filename\n' Raises ------- @@ -285,10 +312,11 @@ def get_checksum(fpath): Provided `fpath` does not exist TypeError `fpath` is not a string + """ - def sha(block): - hasher = hashlib.sha256() + def md5sum(block): + hasher = hashlib.md5() for bytes in block: hasher.update(bytes) return hasher.hexdigest() @@ -300,14 +328,78 @@ def blockiter(fh, blocksize=65536): yield block block = fh.read(blocksize) - if type(fpath) != str: + if not isinstance(fpath, str): raise TypeError(type(fpath)) if not os.path.exists(fpath): raise FileNotFoundError(fpath) fh = open(fpath, "rb") - cksum = sha(blockiter(fh)) + hash = md5sum(blockiter(fh)) fh.close() - return cksum + return f"{hash} {os.path.basename(fpath)}\n" + + +def read_checksum(fpath): + """Read checksum from provided checksum formatted `fpath`. + + Note: See `write_checksum` for how a checksum file is generated. + + Parameters + ---------- + fpath : str + + + Returns + ------- + str + checksum retrieved from `fpath`. + + Raises + ------- + TypeError + Provided `fpath` was not a string. + FileNotFoundError + Provided `fpath` does not exist. + + """ + if not isinstance(fpath, str): + raise TypeError(type(fpath)) + if not os.path.exists(fpath): + raise FileNotFoundError(fpath) + with open(fpath) as fh: + return fh.readline() + + +def write_checksum(infpath, outfpath): + """Calculate checksum for `infpath` and write to `outfpath`. + + Parameters + ---------- + infpath : str + + outfpath : str + + + Returns + ------- + NoneType + Description of returned object. + + Raises + ------- + FileNotFoundError + Provided `infpath` does not exist + TypeError + `infpath` or `outfpath` is not a string + + """ + if not os.path.exists(infpath): + raise FileNotFoundError(infpath) + if not isinstance(outfpath, str): + raise TypeError(type(outfpath)) + checksum = calc_checksum(infpath) + with open(outfpath, "w") as fh: + fh.write(checksum) + logger.debug(f"Wrote {infpath} checksum to {outfpath}") def valid_checkpoint(checkpoint_fp, fpath): @@ -331,9 +423,10 @@ def valid_checkpoint(checkpoint_fp, fpath): Either `fpath` or `checkpoint_fp` does not exist TypeError Either `fpath` or `checkpoint_fp` is not a string + """ for fp in [checkpoint_fp, fpath]: - if not type(fp) is str: + if not isinstance(fp, str): raise TypeError(f"{fp} is type: {type(fp)}") if not os.path.exists(fp): raise FileNotFoundError(fp) @@ -345,8 +438,8 @@ def valid_checkpoint(checkpoint_fp, fpath): # If filepaths never match, prev_chksum and new_chksum will not match. # Giving expected result. break - new_chksum = get_checksum(fpath) - return True if new_chksum == prev_chksum else False + new_chksum = calc_checksum(fpath) + return new_chksum == prev_chksum def get_checkpoints(checkpoint_fp, fpaths=None): @@ -358,7 +451,7 @@ def get_checkpoints(checkpoint_fp, fpaths=None): ---------- checkpoint_fp : str - fpaths : [str, ...] + fpaths : [str, ...], optional [, ...] Returns @@ -371,6 +464,7 @@ def get_checkpoints(checkpoint_fp, fpaths=None): ValueError When `checkpoint_fp` first being written, will not populate an empty checkpoints file. Raises an error if the `fpaths` list is empty or None + """ if not os.path.exists(checkpoint_fp): logger.debug(f"{checkpoint_fp} not found... 
Writing") @@ -381,10 +475,10 @@ def get_checkpoints(checkpoint_fp, fpaths=None): outlines = "" for fpath in fpaths: try: - checksum = get_checksum(fpath) + checksum = calc_checksum(fpath) except FileNotFoundError as err: checksum = "" - outlines += f"{checksum}\t{fpath}\n" + outlines += checksum with open(checkpoint_fp, "w") as fh: fh.write(outlines) logger.debug(f"Written: {checkpoint_fp}") @@ -412,20 +506,19 @@ def update_checkpoints(checkpoint_fp, fpath): ------- dict {fp:checksum, ...} + """ checkpoints = get_checkpoints(checkpoint_fp) if valid_checkpoint(checkpoint_fp, fpath): return checkpoints - new_checksum = get_checksum(fpath) + new_checksum = calc_checksum(fpath) checkpoints.update({fpath: new_checksum}) outlines = "" for fp, chk in checkpoints.items(): outlines += f"{chk}\t{fp}\n" with open(checkpoint_fp, "w") as fh: fh.write(outlines) - logger.debug( - f"Updated checkpoints with {os.path.basename(fpath)} -> {new_checksum[:16]}" - ) + logger.debug(f"Checkpoints updated: {new_checksum[:16]} {os.path.basename(fpath)}") return checkpoints @@ -469,13 +562,9 @@ def wrapper(*args, **kwds): if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser( - description="file containing utilities functions for Autometa pipeline" + print( + "This file contains utilities for Autometa pipeline and should not be run directly!" ) - print("file containing utilities functions for Autometa pipeline") - args = parser.parse_args() import sys - sys.exit(1) + sys.exit(0) diff --git a/autometa/config/databases.py b/autometa/config/databases.py index 19d65d805..2cd676baa 100644 --- a/autometa/config/databases.py +++ b/autometa/config/databases.py @@ -20,19 +20,21 @@ along with Autometa. If not, see . COPYRIGHT -Configuration handling for Autometa Databases. +This file contains the Databases class responsible for configuration handling +of Autometa Databases. """ import logging import os import requests +import tempfile import multiprocessing as mp from configparser import ConfigParser -from configparser import ExtendedInterpolation from ftplib import FTP +from glob import glob from autometa.config import get_config from autometa.config import DEFAULT_FPATH @@ -40,6 +42,9 @@ from autometa.config import put_config from autometa.config import AUTOMETA_DIR from autometa.common.utilities import untar +from autometa.common.utilities import calc_checksum +from autometa.common.utilities import read_checksum +from autometa.common.utilities import write_checksum from autometa.common.external import diamond from autometa.common.external import hmmer @@ -50,10 +55,38 @@ class Databases: - """docstring for Databases.""" + """Database class containing methods to allow downloading/formatting/updating + Autometa database dependencies. + + Parameters + ---------- + config : config.ConfigParser + Config containing database dependency information. + (the default is DEFAULT_CONFIG). + + dryrun : bool + Run through database checking without performing + downloads/formatting (the default is False). + + nproc : int + Number of processors to use to perform database formatting. + (the default is mp.cpu_count()). + + upgrade : bool + Overwrite existing databases with more up-to-date database + files. (the default is False). + + Attributes + ---------- + ncbi_dir : str markers_dir : str + SECTIONS : dict keys are `sections` + respective to database config sections and values are options within + the `sections`. 
+ + """ SECTIONS = { - "ncbi": ["nodes", "names", "merged", "accession2taxid", "nr",], + "ncbi": ["nodes", "names", "merged", "accession2taxid", "nr"], "markers": [ "bacteria_single_copy", "bacteria_single_copy_cutoffs", @@ -62,31 +95,51 @@ class Databases: ], } - def __init__(self, config=DEFAULT_CONFIG, dryrun=False, nproc=mp.cpu_count()): - if type(config) is not ConfigParser: - raise TypeError(f"config is not ConfigParser : {type(config)}") - if type(dryrun) is not bool: - raise TypeError(f"dryrun must be True or False. type: {type(dryrun)}") - - self.config = config - self.dryrun = dryrun - self.nproc = nproc - self.prepare_sections() - self.ncbi_dir = self.config.get("databases", "ncbi") - self.markers_dir = self.config.get("databases", "markers") + def __init__( + self, config=DEFAULT_CONFIG, dryrun=False, nproc=mp.cpu_count(), upgrade=False, + ): + """ - @property - def satisfied(self): - return self.get_missing(validate=True) + At instantiation of Databases instance, if any of the respective + database directories do not exist, they will be created. This will be + reflected in the provided `config`. - def prepare_sections(self): - """Add database sections to 'databases' if missing. + Parameters + ---------- + config : config.ConfigParser + Config containing database dependency information. + (the default is DEFAULT_CONFIG). + dryrun : bool + Run through database checking without performing downloads/ + formatting (the default is False). + + nproc : int + Number of processors to use to perform database formatting. + (the default is mp.cpu_count()). + upgrade : bool + Overwrite existing databases with more up-to-date database files. + (the default is False). Returns ------- - NoneType + databases.Databases: Databases object + instance of Databases class """ + if not isinstance(config, ConfigParser): + raise TypeError(f"config is not ConfigParser : {type(config)}") + if not isinstance(dryrun, bool): + raise TypeError(f"dryrun must be boolean. type: {type(dryrun)}") + + self.config = config + self.dryrun = dryrun + self.nproc = nproc + self.upgrade = upgrade + if self.config.get("common", "home_dir") == "None": + # neccessary if user not running databases through the user + # endpoint. where :func:`~autometa.config.init_default` would've + # been called. + self.config.set("common", "home_dir", AUTOMETA_DIR) if not self.config.has_section("databases"): self.config.add_section("databases") for section in Databases.SECTIONS: @@ -94,9 +147,94 @@ def prepare_sections(self): continue outdir = DEFAULT_CONFIG.get("databases", section) self.config.set("databases", section, outdir) + self.ncbi_dir = self.config.get("databases", "ncbi") + self.markers_dir = self.config.get("databases", "markers") + for outdir in {self.ncbi_dir, self.markers_dir}: + if not os.path.exists(outdir): + os.makedirs(outdir) + + def satisfied(self, section=None, compare_checksums=False): + """Determines whether all database dependencies are satisfied. + + Parameters + ---------- + section : str + section to retrieve for `checksums` section. + Choices include: 'ncbi' and 'markers'. + compare_checksums : bool, optional + Also check if database information is up-to-date with current + hosted databases. (default is False). + + Returns + ------- + bool + True if all database dependencies are satisfied, otherwise False. 
+ + """ + any_missing = self.get_missing(section=section) + if compare_checksums: + any_invalid = self.compare_checksums(section=section) + else: + any_invalid = {} + return not any_missing and not any_invalid + + def get_remote_checksum(self, section, option): + """Get the checksum from provided `section` respective to `option` in + `self.config`. + + Parameters + ---------- + section : str + section to retrieve for `checksums` section. + Choices include: 'ncbi' and 'markers'. + option : str + `option` in `checksums` section corresponding to the section + checksum file. + + Returns + ------- + str + checksum of remote md5 file. e.g. 'hash filename\n' + + Raises + ------- + ConnectionError + Failed to connect to host for provided `option`. + + """ + if section not in {"ncbi", "markers"}: + raise ValueError( + f'"section" must be "ncbi" or "markers". Provided: {section}' + ) + if section == "ncbi": + host = self.config.get(section, "host") + ftp_fullpath = self.config.get("checksums", option) + chksum_fpath = ftp_fullpath.split(host)[-1] + with FTP(host) as ftp, tempfile.TemporaryFile() as fp: + ftp.login() + result = ftp.retrbinary(f"RETR {chksum_fpath}", fp.write) + if not result.startswith("226 Transfer complete"): + raise ConnectionError(f"{chksum_fpath} download failed") + ftp.quit() + fp.seek(0) + checksum = fp.read().decode() + elif section == "markers": + url = self.config.get("checksums", option) + with requests.Session() as session: + resp = session.get(url) + if not resp.ok: + raise ConnectionError(f"Failed to retrieve {url}") + checksum = resp.text + return checksum def format_nr(self): - """Format NCBI nr.gz database into diamond formatted database nr.dmnd. + """Construct a diamond formatted database (nr.dmnd) from `nr` option + in `ncbi` section in user config. + + NOTE: The checksum 'nr.dmnd.md5' will only be generated if nr.dmnd + construction is successful. If the provided `nr` option in `ncbi` is + 'nr.gz' the database will be removed after successful database + formatting. 
Returns ------- @@ -105,23 +243,82 @@ def format_nr(self): """ db_infpath = self.config.get("ncbi", "nr") + db_infpath_md5 = f"{db_infpath}.md5" db_outfpath = db_infpath.replace(".gz", ".dmnd") - if not self.dryrun and not os.path.exists(db_outfpath): - diamond.makedatabase(fasta=nr, database=db_infpath, nproc=self.nproc) + + db_outfpath_exists = os.path.exists(db_outfpath) + if db_outfpath_exists: + db_outfpath_hash, __ = calc_checksum(db_outfpath).split() + + remote_checksum_matches = False + current_nr_checksum_matches = False + # Check database and database checksum is up-to-date + if os.path.exists(db_infpath_md5) and db_outfpath_exists: + # Check if the current db md5 is up-to-date with the remote db md5 + current_hash, __ = read_checksum(db_infpath_md5).split() + remote_hash, __ = self.get_remote_checksum("ncbi", "nr").split() + if remote_hash == current_hash: + remote_checksum_matches = True + # Check if the current db md5 matches the calc'd db checksum + if db_outfpath_hash == current_hash: + current_nr_checksum_matches = True + + db_outfpath_md5 = f"{db_outfpath}.md5" + db_outfpath_md5_checksum_matches = False + if os.path.exists(db_outfpath_md5) and db_outfpath_exists: + db_outfpath_md5_hash, __ = read_checksum(db_outfpath_md5).split() + if db_outfpath_hash == db_outfpath_md5_hash: + db_outfpath_md5_checksum_matches = True + + checksum_checks = ["nr.dmnd.md5", "nr.gz.md5", "remote nr.gz.md5"] + checksum_matches = [ + db_outfpath_md5_checksum_matches, + current_nr_checksum_matches, + remote_checksum_matches, + ] + for checksum_match, checksum_check in zip(checksum_matches, checksum_checks): + # If the checksums do not match, we need to update the database file. + if checksum_match: + logger.debug(f"{checksum_check} checksum matches, skipping...") + self.config.set("ncbi", "nr", db_outfpath) + logger.debug(f"set ncbi nr: {db_outfpath}") + return + # Only update out-of-date db files if user wants to update via self.upgrade + if not self.upgrade and checksum_check == "remote nr.gz.md5": + return + + diamond.makedatabase(fasta=db_infpath, database=db_outfpath, nproc=self.nproc) + # Write checksum for nr.dmnd + write_checksum(db_outfpath, db_outfpath_md5) + + if os.path.basename(db_infpath) == "nr.gz": + # nr.gz will be removed after successful nr.dmnd construction + os.remove(db_infpath) + self.config.set("ncbi", "nr", db_outfpath) logger.debug(f"set ncbi nr: {db_outfpath}") def extract_taxdump(self): - """Extract autometa required files from ncbi taxdump directory. + """Extract autometa required files from ncbi taxdump.tar.gz archive + into ncbi databases directory and update user config with extracted + paths. - Extracts nodes.dmp, names.dmp and merged.dmp from taxdump.tar.gz + This only extracts nodes.dmp, names.dmp and merged.dmp from + taxdump.tar.gz if the files do not already exist. If `upgrade` + was originally supplied as `True` to the Databases instance, then the + previous files will be replaced by the new taxdump files. + + After successful extraction of the files, a checksum will be written + of the archive for future checking. 
Returns ------- NoneType + Will update `self.config` section `ncbi` with options 'nodes', + 'names','merged' """ - taxdump = self.config.get("ncbi", "taxdump") + taxdump_fpath = self.config.get("ncbi", "taxdump") taxdump_files = [ ("nodes", "nodes.dmp"), ("names", "names.dmp"), @@ -129,67 +326,121 @@ def extract_taxdump(self): ] for option, fname in taxdump_files: outfpath = os.path.join(self.ncbi_dir, fname) - if not self.dryrun and not os.path.exists(outfpath): - outfpath = untar(taxdump, self.ncbi_dir, fname) + if self.dryrun: + logger.debug(f"UPDATE (ncbi,{option}): {outfpath}") + self.config.set("ncbi", option, outfpath) + continue + # Only update the taxdump files if the user says to do an update. + if self.upgrade and os.path.exists(outfpath): + os.remove(outfpath) + # Only extract the taxdump files if this is not a "dryrun" + if not os.path.exists(outfpath): + outfpath = untar(taxdump_fpath, self.ncbi_dir, fname) + write_checksum(outfpath, f"{outfpath}.md5") + logger.debug(f"UPDATE (ncbi,{option}): {outfpath}") self.config.set("ncbi", option, outfpath) - def update_ncbi(self, options): - """Update NCBI database files (taxdump.tar.gz and nr.gz). + def download_ncbi_files(self, options): + """Download NCBI database files. Parameters ---------- - options : set - Set of options to update + options : iterable + iterable containing options in 'ncbi' section to download. Returns ------- NoneType - Description of returned object. + Will update provided `options` in `self.config`. Raises ------- + subprocess.CalledProcessError + NCBI file download with rsync failed. ConnectionError - NCBI file download failed. + NCBI file checksums do not match after file transfer. """ - # Download required NCBI database files - if not os.path.exists(self.ncbi_dir): - os.makedirs(self.ncbi_dir) - host = DEFAULT_CONFIG.get("ncbi", "host") + host = self.config.get("ncbi", "host") for option in options: - ftp_fullpath = DEFAULT_CONFIG.get("database_urls", option) - ftp_fpath = ftp_fullpath.split(host)[-1] - if self.config.has_option("ncbi", option): + ftp_fullpath = self.config.get("database_urls", option) + + if ( + self.config.has_option("ncbi", option) + and self.config.get("ncbi", option) is not None + ): outfpath = self.config.get("ncbi", option) else: - outfname = os.path.basename(ftp_fpath) + outfname = os.path.basename(ftp_fullpath) outfpath = os.path.join(self.ncbi_dir, outfname) - if not self.dryrun and not os.path.exists(outfpath): - with FTP(host) as ftp, open(outfpath, "wb") as fp: - ftp.login() - logger.debug(f"starting {option} download") - result = ftp.retrbinary(f"RETR {ftp_fpath}", fp.write) - if not result.startswith("226 Transfer complete"): - raise ConnectionError(f"{option} download failed") - ftp.quit() + logger.debug(f"UPDATE: (ncbi,{option}): {outfpath}") self.config.set("ncbi", option, outfpath) - # Extract/format respective NCBI files - self.extract_taxdump() - self.format_nr() - def update_markers(self, options): - """Update single-copy markers hmms and cutoffs. 
+ if self.dryrun: + return + + rsync_fpath = ftp_fullpath.replace("ftp", "rsync") + cmd = ["rsync", "--quiet", "--archive", rsync_fpath, outfpath] + logger.debug(f"starting {option} download") + subprocess.run( + cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True + ) + checksum_outfpath = f"{outfpath}.md5" + write_checksum(outfpath, checksum_outfpath) + current_checksum = read_checksum(checksum_outfpath) + current_hash, __ = current_checksum.split() + remote_checksum = self.get_remote_checksum("ncbi", option) + remote_hash, __ = remote_checksum.split() + if current_checksum != remote_hash: + raise ConnectionError(f"{option} download failed") + if "taxdump" in options: + self.extract_taxdump() + if "nr" in options: + self.format_nr() + + def press_hmms(self): + """hmmpress markers hmm database files. + + Returns + ------- + NoneType + + """ + hmm_search_str = os.path.join(self.markers_dir, "*.h3?") + # First search for pressed hmms to remove from list to hmmpress + pressed_hmms = { + os.path.realpath(os.path.splitext(fp)[0]) + for fp in glob(hmm_search_str) + if not fp.endswith(".md5") + } + # Now retrieve all hmms in markers directory + hmms = ( + os.path.join(self.markers_dir, fn) + for fn in os.listdir(self.markers_dir) + if fn.endswith(".hmm") + ) + # Filter by hmms not already pressed + hmms = (fpath for fpath in hmms if fpath not in pressed_hmms) + # Press hmms and write checksums of their indices + for hmm_fp in hmms: + hmmer.hmmpress(hmm_fp) + for index_fp in glob(f"{hmm_fp}.h3?"): + write_checksum(index_fp, f"{index_fp}.md5") + + def download_markers(self, options): + """Download markers database files and amend user config to reflect this. Parameters ---------- - options : set - Description of parameter `options`. + options : iterable + iterable containing options in 'markers' section to download. Returns ------- NoneType + Will update provided `options` in `self.config`. Raises ------- @@ -197,105 +448,239 @@ def update_markers(self, options): marker file download failed. 
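+
+        Example
+        -------
+        Illustrative call (assumes the markers directory is writable and the
+        markers host is reachable):
+
+        >>> dbs = Databases()
+        >>> dbs.download_markers(["bacteria_single_copy", "bacteria_single_copy_cutoffs"])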
""" - if not os.path.exists(self.markers_dir): - os.makedirs(self.markers_dir) for option in options: - url = DEFAULT_CONFIG.get("database_urls", option) + # First retrieve the markers file url from `option` in `markers` + url = self.config.get("database_urls", option) if self.config.has_option("markers", option): outfpath = self.config.get("markers", option) else: outfname = os.path.basename(url) outfpath = os.path.join(self.markers_dir, outfname) + if self.dryrun: logger.debug(f"UPDATE: (markers,{option}): {outfpath}") self.config.set("markers", option, outfpath) continue - with requests.Session() as session: + + # Retrieve markers file and write contents to `outfpath` + with requests.Session() as session, open(outfpath, "w") as fh: resp = session.get(url) - if not resp.ok: - raise ConnectionError(f"Failed to retrieve {url}") - with open(outfpath, "w") as outfh: - outfh.write(resp.text) + if not resp.ok: + raise ConnectionError(f"Failed to retrieve {url}") + fh.write(resp.text) self.config.set("markers", option, outfpath) - if outfpath.endswith(".hmm"): - hmmer.hmmpress(outfpath) + checksum_outfpath = f"{outfpath}.md5" + write_checksum(outfpath, checksum_outfpath) + current_checksum = read_checksum(checksum_outfpath) + current_hash, __ = current_checksum.split() + remote_checksum = self.get_remote_checksum("markers", option) + remote_hash, __ = remote_checksum.split() + if current_checksum != remote_hash: + raise ConnectionError(f"{option} download failed") + self.press_hmms() + + def get_missing(self, section=None): + """Get all missing database files in `options` from `sections` + in config. + + Parameters + ---------- + section : str, optional + Configure provided `section`. Choices include 'markers' and 'ncbi'. + (default will download/format all database directories) + + Returns + ------- + dict + {section:{option, option,...}, section:{...}, ...} - def get_missing(self, validate=False): - """Retrieve all database files from all database sections that are not - available. + """ + sections = [section] if section else Databases.SECTIONS.keys() + missing = {} + for section in sections: + for option in self.config.options(section): + # Skip user added options not required by Autometa + if option not in Databases.SECTIONS.get(section): + continue + fpath = self.config.get(section, option) + if os.path.exists(fpath): + continue + if section in missing: + missing[section].add(option) + else: + missing.update({section: set([option])}) + # Log missing options + for section, options in missing.items(): + for option in options: + logger.debug(f"MISSING: ({section},{option})") + return missing + + def download_missing(self, section=None): + """Download missing Autometa database dependencies from provided `section`. + If no `section` is provided will check all sections. + + Parameters + ---------- + section : str, optional + Section to check for missing database files (the default is None). + Choices include 'ncbi' and 'markers'. Returns ------- - bool or dict + NoneType + Will update provided `section` in `self.config`. - - if `validate` is True : bool + Raises + ------- + ValueError + Provided `section` does not match 'ncbi' and 'markers'. 
- all available evaluates to True, otherwise False + """ + dispatcher = { + "ncbi": self.download_ncbi_files, + "markers": self.download_markers, + } + if section and section not in dispatcher: + raise ValueError(f'{section} does not match "ncbi" or "markers"') + if section: + missing = self.get_missing(section=section) + options = missing.get(section, []) + dispatcher[section](options) + else: + missing = self.get_missing() + for section, options in missing.items(): + dispatcher[section](options) - - if `validate` is False : dict + def compare_checksums(self, section=None): + """Get all invalid database files in `options` from `section` + in config. An md5 checksum comparison will be performed between the + current and file's remote md5 to ensure file integrity prior to + checking the respective file as valid. - {section:{option, option,...}, section:{...}, ...} + Parameters + ---------- + section : str, optional Configure provided `section` Choices include + 'markers' and 'ncbi'. (default will download/format all database + directories) + + Returns + ------- + dict {section:{option, option,...}, section:{...}, ...} """ - missing = {} - for section in Databases.SECTIONS: + sections = [section] if section else Databases.SECTIONS.keys() + invalid = {} + taxdump_checked = False + for section in sections: for option in self.config.options(section): if option not in Databases.SECTIONS.get(section): # Skip user added options not required by Autometa continue + # nodes.dmp, names.dmp and merged.dmp are all in taxdump.tar.gz + option = "taxdump" if option in {"nodes", "names", "merged"} else option fpath = self.config.get(section, option) - if os.path.exists(fpath) and os.stat(fpath).st_size >= 0: - # TODO: [Checkpoint validation] - logger.debug(f"({section},{option}): {fpath}") + fpath_md5 = f"{fpath}.md5" + # We can not checksum a file that does not exist. + if not os.path.exists(fpath) and not os.path.exists(fpath_md5): continue - if validate: - return False - if section in missing: - missing[section].add(option) + # To not waste time checking the taxdump files 3 times. + if option == "taxdump" and taxdump_checked: + continue + if os.path.exists(fpath_md5): + current_checksum = read_checksum(fpath_md5) else: - missing.update({section: set([option])}) - for section, opts in missing.items(): - for opt in opts: - logger.debug(f"MISSING: ({section},{opt})") - return True if validate else missing + current_checksum = calc_checksum(fpath) + current_hash, __ = current_checksum.split() + try: + remote_checksum = self.get_remote_checksum(section, option) + remote_hash, __ = remote_checksum.split() + except ConnectionError as err: + # Do not mark file as invalid if a connection error occurs. + logger.warning(err) + continue + if option == "taxdump": + taxdump_checked = True + if remote_hash == current_hash: + logger.debug(f"{option} checksums match, skipping...") + continue + if section in invalid: + invalid[section].add(option) + else: + invalid.update({section: set([option])}) + # Log invalid options + for section, options in invalid.items(): + for option in options: + logger.debug(f"INVALID: ({section},{option})") + return invalid - def update_missing(self): - """Download and format databases for all options in each section. + def fix_invalid_checksums(self, section=None): + """Download/Update/Format databases where checksums are out-of-date. 
- NOTE: This will only perform the download and formatting if self.dryrun is False + Parameters + ---------- + section : str, optional + Configure provided `section`. Choices include 'markers' and 'ncbi'. + (default will download/format all database directories) Returns ------- NoneType - config updated with required missing sections. + Will update provided `options` in `self.config`. + + Raises + ------- + ConnectionError + Failed to connect to `section` host site. """ - dispatcher = {"ncbi": self.update_ncbi, "markers": self.update_markers} - missing = self.get_missing() - for section, options in missing.items(): - if section == "ncbi": - if "nodes" in options or "names" in options or "merged" in options: - options.discard("nodes") - options.discard("names") - options.discard("merged") - options.add("taxdump") + dispatcher = { + "ncbi": self.download_ncbi_files, + "markers": self.download_markers, + } + if section and section not in dispatcher: + raise ValueError(f'{section} does not match "ncbi" or "markers"') + if section: + invalid = self.compare_checksums(section=section) + options = invalid.get(section, set()) dispatcher[section](options) + else: + invalid = self.compare_checksums() + for section, options in invalid.items(): + dispatcher[section](options) - def configure(self): - """Checks database files + def configure(self, section=None, no_checksum=False): + """Configures Autometa's database dependencies by first checking missing + dependencies then comparing checksums to ensure integrity of files. + + Download and format databases for all options in each section. + + This will only perform the download and formatting if `self.dryrun` is + False. This will update out-of-date databases if `self.upgrade` is + True. + + Parameters + ---------- + section : str, optional Configure provided `section`. Choices include + 'markers' and 'ncbi'. (default will download/format all database + directories) no_checksum : bool, optional Do not perform checksum + comparisons (Default is False). Returns ------- - configparser.ConfigParser - config with updated options in respective databases sections. + configparser.ConfigParser config with updated options in respective + databases sections. Raises ------- - ExceptionName - Why the exception is raised. + ValueError Provided `section` does not match 'ncbi' or 'markers'. + ConnectionError A connection issue occurred when connecting to NCBI + or GitHub. 
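+
+        Example
+        -------
+        A minimal sketch (assumes any missing or invalid files can be
+        retrieved from their hosts):
+
+        >>> dbs = Databases(upgrade=True)
+        >>> config = dbs.configure(section="ncbi", no_checksum=False)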
""" - self.update_missing() + self.download_missing(section=section) + if no_checksum: + return self.config + self.fix_invalid_checksums(section=section) return self.config @@ -303,7 +688,6 @@ def main(): import argparse import logging as logger - cpus = mp.cpu_count() logger.basicConfig( format="[%(asctime)s %(levelname)s] %(name)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", @@ -311,11 +695,13 @@ def main(): ) parser = argparse.ArgumentParser( - description="databases config", - epilog="By default, with no arguments, will download/format databases into default databases directory.", + description="Main script to configure Autometa database dependencies.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + epilog="By default, with no arguments, will download/format databases " + "into default databases directory.", ) parser.add_argument( - "--config", help="", default=DEFAULT_FPATH + "--config", help="", default=DEFAULT_FPATH, ) parser.add_argument( "--dryrun", @@ -323,27 +709,81 @@ def main(): action="store_true", default=False, ) + parser.add_argument( + "--update", + help="Update all out-of-date databases.", + action="store_true", + default=False, + ) + parser.add_argument( + "--update-markers", + help="Update out-of-date markers databases.", + action="store_true", + default=False, + ) + parser.add_argument( + "--update-ncbi", + help="Update out-of-date ncbi databases.", + action="store_true", + default=False, + ) + parser.add_argument( + "--check-dependencies", + help="Check database dependencies are satisfied.", + action="store_true", + default=False, + ) + parser.add_argument( + "--no-checksum", + help="Do not perform remote checksum comparisons to validate databases" + "are up-to-date.", + action="store_true", + default=False, + ) parser.add_argument( "--nproc", - help=f"num. cpus to use for DB formatting. (default {cpus})", + help="num. 
cpus to use for DB formatting.", type=int, - default=cpus, + default=mp.cpu_count(), ) parser.add_argument("--out", help="") args = parser.parse_args() config = get_config(args.config) - dbs = Databases(config=config, dryrun=args.dryrun, nproc=args.nproc) - logger.debug(f"Configuring databases") - config = dbs.configure() - dbs = Databases(config=config, dryrun=args.dryrun, nproc=args.nproc) - logger.info(f"Database dependencies satisfied: {dbs.satisfied}") + dbs = Databases( + config=config, dryrun=args.dryrun, nproc=args.nproc, upgrade=args.update, + ) + + compare_checksums = False + for update_section in [args.update, args.update_markers, args.update_ncbi]: + if update_section and not args.no_checksum: + compare_checksums = True + + if args.update_markers: + section = "markers" + elif args.update_ncbi: + section = "ncbi" + else: + section = None + + if args.check_dependencies: + dbs_satisfied = dbs.satisfied( + section=section, compare_checksums=compare_checksums + ) + logger.info(f"Database dependencies satisfied: {dbs_satisfied}") + import sys + + sys.exit(0) + + # logger.debug(f'Configuring databases') + config = dbs.configure(section=section, no_checksum=args.no_checksum) + if not args.out: import sys sys.exit(0) put_config(config, args.out) - logger.debug(f"{args.out} written.") + logger.info(f"{args.out} written.") if __name__ == "__main__": diff --git a/autometa/config/default.config b/autometa/config/default.config index 0b4a1ac38..7475f595c 100644 --- a/autometa/config/default.config +++ b/autometa/config/default.config @@ -56,15 +56,19 @@ markers = ${databases:base}/markers taxdump = ftp://${ncbi:host}/pub/taxonomy/taxdump.tar.gz accession2taxid = ftp://${ncbi:host}/pub/taxonomy/accession2taxid/prot.accession2taxid.gz nr = ftp://${ncbi:host}/blast/db/FASTA/nr.gz -bacteria_single_copy = https://github.com/WiscEvan/Autometa/raw/dev/databases/markers/bacteria.single_copy.hmm -bacteria_single_copy_cutoffs = https://${markers:host}/WiscEvan/Autometa/dev/databases/markers/bacteria.single_copy.cutoffs?token=AGF3KQVL3J4STDT4TJQVDBS6GG5FE -archaea_single_copy = https://github.com/WiscEvan/Autometa/raw/dev/databases/markers/archaea.single_copy.hmm -archaea_single_copy_cutoffs = https://${markers:host}/WiscEvan/Autometa/dev/databases/markers/archaea.single_copy.cutoffs?token=AGF3KQXVUDFIH6ECVTYMZQS6GG5KO +bacteria_single_copy = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/bacteria.single_copy.hmm +bacteria_single_copy_cutoffs = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/bacteria.single_copy.cutoffs +archaea_single_copy = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/archaea.single_copy.hmm +archaea_single_copy_cutoffs = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/archaea.single_copy.cutoffs [checksums] taxdump = ftp://${ncbi:host}/pub/taxonomy/taxdump.tar.gz.md5 accession2taxid = ftp://${ncbi:host}/pub/taxonomy/accession2taxid/prot.accession2taxid.gz.md5 nr = ftp://${ncbi:host}/blast/db/FASTA/nr.gz.md5 +bacteria_single_copy = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/bacteria.single_copy.hmm.md5 +bacteria_single_copy_cutoffs = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/bacteria.single_copy.cutoffs.md5 +archaea_single_copy = https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/archaea.single_copy.hmm.md5 +archaea_single_copy_cutoffs = 
https://${markers:host}/KwanLab/Autometa/dev/autometa/databases/markers/archaea.single_copy.cutoffs.md5 [ncbi] host = ftp.ncbi.nlm.nih.gov @@ -98,8 +102,8 @@ bed = alignments.bed length_filtered = metagenome.filtered.fna coverages = coverages.tsv kmer_counts = kmers.tsv -kmer_normalized = kmers.normalized.tsv -kmer_embedded = kmers.embedded.tsv +kmer_normalized = kmers.normalized.tsv +kmer_embedded = kmers.embedded.tsv nucleotide_orfs = metagenome.filtered.orfs.fna amino_acid_orfs = metagenome.filtered.orfs.faa blastp = blastp.tsv diff --git a/autometa/config/user.py b/autometa/config/user.py index 77483f72d..caa599333 100644 --- a/autometa/config/user.py +++ b/autometa/config/user.py @@ -34,7 +34,7 @@ from autometa.common import utilities from autometa.common.metagenome import Metagenome -from autometa.common.mag import MAG +from autometa.common.metabin import MetaBin from autometa.config.databases import Databases from autometa.config.project import Project from autometa.common.utilities import timeit diff --git a/autometa/databases/markers/archaea.single_copy.cutoffs.md5 b/autometa/databases/markers/archaea.single_copy.cutoffs.md5 new file mode 100644 index 000000000..0be5450fe --- /dev/null +++ b/autometa/databases/markers/archaea.single_copy.cutoffs.md5 @@ -0,0 +1 @@ +d0121bd751e0d454a81c54a0655ce745 archaea.single_copy.cutoffs diff --git a/autometa/databases/markers/archaea.single_copy.hmm.md5 b/autometa/databases/markers/archaea.single_copy.hmm.md5 new file mode 100644 index 000000000..2a966fe72 --- /dev/null +++ b/autometa/databases/markers/archaea.single_copy.hmm.md5 @@ -0,0 +1 @@ +fb0f644a2df855fa7145436564091e07 archaea.single_copy.hmm diff --git a/autometa/databases/markers/bacteria.single_copy.cutoffs.md5 b/autometa/databases/markers/bacteria.single_copy.cutoffs.md5 new file mode 100644 index 000000000..fe76ec24c --- /dev/null +++ b/autometa/databases/markers/bacteria.single_copy.cutoffs.md5 @@ -0,0 +1 @@ +936e6cb46002caa08b3690b8446086fb bacteria.single_copy.cutoffs diff --git a/autometa/databases/markers/bacteria.single_copy.hmm.md5 b/autometa/databases/markers/bacteria.single_copy.hmm.md5 new file mode 100644 index 000000000..44b324798 --- /dev/null +++ b/autometa/databases/markers/bacteria.single_copy.hmm.md5 @@ -0,0 +1 @@ +eafdb5bd1447e814a4ee47ee11434ffe bacteria.single_copy.hmm diff --git a/docs/parse_argparse.py b/docs/parse_argparse.py index 08c5f1124..ab47b8974 100644 --- a/docs/parse_argparse.py +++ b/docs/parse_argparse.py @@ -103,6 +103,11 @@ def get_usage(argparse_lines): ------- wrapped_lines : string indented arparse output after running the `--help` command + + Raises + ------ + subprocess.CalledProcessError + Error while running --help on these argparse lines """ __, tmp_fpath = tempfile.mkstemp() with open(tmp_fpath, "w") as outfh: diff --git a/docs/source/conf.py b/docs/source/conf.py index 8d3f5c97f..bd9d48d07 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,7 +19,8 @@ autodoc_mock_imports = ["Bio", "hdbscan", "tsne", "sklearn", "umap", "tqdm"] -import parse_argparse # nopep8 +# fmt: off +import parse_argparse # -- Project information ----------------------------------------------------- @@ -48,7 +49,8 @@ ] todo_include_todos = True - +# Includes doctrings of functions that begin with double underscore +napoleon_include_special_with_doc = True autosummary_generate = True # Add any paths that contain templates here, relative to this directory.