Skip to content

Commit

Permalink
Merge pull request #101 from johnlees/sample_qc
Browse files Browse the repository at this point in the history
Sample QC
  • Loading branch information
johnlees committed Aug 26, 2020
2 parents dcfca15 + 6e49f13 commit aa27558
Show file tree
Hide file tree
Showing 9 changed files with 279 additions and 187 deletions.
101 changes: 74 additions & 27 deletions PopPUNK/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,23 @@ def get_options():

# qc options
qcGroup = parser.add_argument_group('Quality control options')
qcGroup.add_argument('--max-a-dist', default = 0.5, type=float, help='Maximum accessory distance to permit '
'[default = 0.5]')
qcGroup.add_argument('--ignore-length', help='Ignore outliers in terms of assembly length '
'[default = False]', default=False, action='store_true')
qcGroup.add_argument('--estimated-length', default=2000000, type = int, help='Provide an integer estimated genome length when using "--ignore-length" [default = 2000000]')

qcGroup.add_argument('--qc-filter', help='Behaviour following sequence QC step: "stop" [default], "prune"'
'(analyse data passing QC), or "continue" (analyse all data)',
default='stop', type = str, choices=['stop', 'prune', 'continue'])
qcGroup.add_argument('--retain-failures', help='Retain sketches of genomes that do not pass QC filters in '
'separate database [default = False]', default=False, action='store_true')
qcGroup.add_argument('--max-a-dist', help='Maximum accessory distance to permit [default = 0.5]',
default = 0.5, type = float)
qcGroup.add_argument('--length-sigma', help='Number of standard deviations of length distribution beyond '
'which sequences will be excluded [default = 5]', default = 5, type = int)
qcGroup.add_argument('--length-range', help='Allowed length range, outside of which sequences will be excluded '
'[two values needed - lower and upper bounds]', default=[None,None],
type = int, nargs = 2)
qcGroup.add_argument('--prop-n', help='Threshold ambiguous base proportion above which sequences will be excluded'
' [default = 0.1]', default = 0.1,
type = float)
qcGroup.add_argument('--upper-n', help='Threshold ambiguous base count above which sequences will be excluded',
default=None, type = int)

# model fitting
modelGroup = parser.add_argument_group('Model fit options')
Expand Down Expand Up @@ -266,6 +277,16 @@ def main():
queryDatabase = dbFuncs['queryDatabase']
readDBParams = dbFuncs['readDBParams']

# Dict of QC options for passing to database construction and querying functions
qc_dict = {
'qc_filter': args.qc_filter,
'retain_failures': args.retain_failures,
'length_sigma': args.length_sigma,
'length_range': args.length_range,
'prop_n': args.prop_n,
'upper_n': args.upper_n
}

# define sketch sizes, store in hash in case one day
# different kmers get different hash sizes
sketch_sizes = {}
Expand All @@ -279,6 +300,13 @@ def main():
if not args.use_mash:
sketch_sizes = int(round(max(sketch_sizes.values())/64))

# if a length range is specified, check it makes sense
if args.length_range[0] is not None:
if args.length_range[0] >= args.length_range[1]:
sys.stderr.write('Ensure the specified length range is space-separated argument of'
' length 2, with the lower value first\n')
sys.exit(1)

# check if working with lineages
rank_list = []
if args.lineage_clustering or args.assign_lineages:
Expand Down Expand Up @@ -327,18 +355,24 @@ def main():
elif args.easy_run:
sys.stderr.write("Mode: Creating clusters from assemblies (create_db & fit_model)\n")
if args.r_files is not None:
# Sketch
# generate sketches and QC sequences
createDatabaseDir(args.output, kmers)
constructDatabase(args.r_files, kmers, sketch_sizes, args.output, args.estimated_length, args.ignore_length, args.threads,
args.overwrite)
seq_names = constructDatabase(args.r_files, kmers, sketch_sizes,
args.output,
args.threads,
args.overwrite,
strand_preserved = args.strand_preserved,
min_count = args.min_kmer_count,
use_exact = args.exact_count,
qc_dict = qc_dict)

# Calculate and QC distances
if args.use_mash == True:
rNames = None
qNames = readRfile(args.r_files, oneSeq=True)[1]
qNames = seq_names
else:
rNames = readRfile(args.r_files)[0]
qNames = rNames
rNames = seq_names
qNames = seq_names
refList, queryList, distMat = queryDatabase(rNames = rNames,
qNames = qNames,
dbPrefix = args.output,
Expand Down Expand Up @@ -588,7 +622,7 @@ def main():
# Read and overwrite previous database
kmers, sketch_sizes = readDBParams(ref_db, kmers, sketch_sizes)
constructDatabase(dummyRefFile, kmers, sketch_sizes, args.output,
args.estimated_length, True, args.threads, True) # overwrite old db
True, args.threads, True) # overwrite old db
os.remove(dummyRefFile)

genomeNetwork.save(args.output + "/" + os.path.basename(args.output) + '_graph.gt', fmt = 'gt')
Expand All @@ -601,11 +635,13 @@ def main():
#*******************************#
elif args.assign_query or args.assign_lineages:
assign_query(dbFuncs, args.ref_db, args.q_files, args.output, args.update_db, args.full_db, args.distances,
args.microreact, args.cytoscape, kmers, sketch_sizes, args.ignore_length, args.estimated_length,
args.microreact, args.cytoscape, kmers, sketch_sizes,
args.threads, args.use_mash, args.mash, args.overwrite, args.plot_fit, args.no_stream,
args.max_a_dist, args.model_dir, args.previous_clustering, args.external_clustering,
args.core_only, args.accessory_only, args.phandango, args.grapetree, args.info_csv,
args.rapidnj, args.perplexity, args.assign_lineages, args.existing_scheme, rank_list, args.use_accessory)
args.rapidnj, args.perplexity, args.assign_lineages, args.existing_scheme, rank_list, args.use_accessory,
strand_preserved = args.strand_preserved, min_count = args.min_kmer_count,
use_exact = args.exact_count, qc_dict = qc_dict)

#******************************#
#* *#
Expand Down Expand Up @@ -757,10 +793,12 @@ def main():
#* *#
#*******************************#
def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances, microreact, cytoscape,
kmers, sketch_sizes, ignore_length, estimated_length, threads, use_mash, mash, overwrite,
kmers, sketch_sizes, threads, use_mash, mash, overwrite,
plot_fit, no_stream, max_a_dist, model_dir, previous_clustering,
external_clustering, core_only, accessory_only, phandango, grapetree,
info_csv, rapidnj, perplexity, assign_lineage, existing_scheme, rank_list, use_accessory):
info_csv, rapidnj, perplexity, assign_lineage, existing_scheme, rank_list, use_accessory,
# added extra arguments for constructing sketchlib libraries
strand_preserved, min_count, use_exact, qc_dict):
"""Code for assign query mode. Written as a separate function so it can be called
by pathogen.watch API
"""
Expand Down Expand Up @@ -792,23 +830,30 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances

# Sketch query sequences
createDatabaseDir(output, kmers)
constructDatabase(q_files, kmers, sketch_sizes, output,
estimated_length, ignore_length, threads, overwrite)

# Find distances vs ref seqs
rNames = []
if use_mash == True:
rNames = None
qNames = readRfile(q_files, oneSeq=True)[1]
# construct database and QC
qNames = constructDatabase(q_files, kmers, sketch_sizes, output,
threads, overwrite)
else:
qNames = readRfile(q_files)[0]
if os.path.isfile(ref_db + "/" + os.path.basename(ref_db) + ".refs"):
with open(ref_db + "/" + os.path.basename(ref_db) + ".refs") as refFile:
for reference in refFile:
rNames.append(reference.rstrip())
else:
rNames = getSeqsInDb(ref_db + "/" + os.path.basename(ref_db) + ".h5")

# construct database and QC
qNames = constructDatabase(q_files, kmers, sketch_sizes, output,
threads, overwrite,
strand_preserved = strand_preserved,
min_count = min_count,
use_exact = use_exact,
qc_dict = qc_dict)

# run query
refList, queryList, distMat = queryDatabase(rNames = rNames,
qNames = qNames,
dbPrefix = ref_db,
Expand All @@ -817,6 +862,8 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
self = False,
number_plot_fits = plot_fit,
threads = threads)

# QC distance matrix
qcPass = qcDistMat(distMat, refList, queryList, max_a_dist)

# Calculate query-query distances
Expand All @@ -826,8 +873,8 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
if assign_lineage:

# Assign lineages by calculating query-query information
ordered_queryList, query_distMat = calculateQueryQueryDistances(dbFuncs, refList, q_files,
kmers, estimated_length, output, use_mash, threads)
ordered_queryList, query_distMat = calculateQueryQueryDistances(dbFuncs, refList, qNames,
kmers, output, use_mash, threads)

else:
# Assign these distances as within or between strain
Expand All @@ -854,8 +901,8 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
core_only, accessory_only)

# Assign clustering by adding to network
ordered_queryList, query_distMat = addQueryToNetwork(dbFuncs, refList, q_files,
genomeNetwork, kmers, estimated_length, queryAssignments, model, output, update_db,
ordered_queryList, query_distMat = addQueryToNetwork(dbFuncs, refList, queryList, q_files,
genomeNetwork, kmers, queryAssignments, model, output, update_db,
use_mash, threads)

# if running simple query
Expand Down Expand Up @@ -890,7 +937,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
if newQueries != queryList and use_mash:
tmpRefFile = writeTmpFile(newQueries)
constructDatabase(tmpRefFile, kmers, sketch_sizes, output,
estimated_length, True, threads, True) # overwrite old db
True, threads, True) # overwrite old db
os.remove(tmpRefFile)
# With mash, this is the reduced DB constructed,
# with sketchlib, all sketches
Expand Down
21 changes: 5 additions & 16 deletions PopPUNK/mash.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
sys.exit(0)

from .utils import iterDistRows
from .utils import assembly_qc
from .utils import readRfile

from .plot import plot_fit
Expand Down Expand Up @@ -292,8 +291,7 @@ def joinDBs(db1, db2, output, klist, mash_exec = 'mash'):


def constructDatabase(assemblyList, klist, sketch_size, oPrefix,
estimated_length, ignoreLengthOutliers = False,
threads = 1, overwrite = False, reads = False,
threads = 1, overwrite = False,
mash_exec = 'mash'):
"""Sketch the input assemblies at the requested k-mer lengths
Expand All @@ -313,24 +311,12 @@ def constructDatabase(assemblyList, klist, sketch_size, oPrefix,
Size of sketch (``-s`` option)
oPrefix (str)
Output prefix for resulting sketch files
estimated_length (int)
Estimated length of genome, if not calculated from data
ignoreLengthOutliers (bool)
Whether to check for outlying genome lengths (and error
if found)
(default = False)
threads (int)
Number of threads to use
(default = 1)
overwrite (bool)
Whether to overwrite sketch DBs, if they already exist.
(default = False)
reads (bool)
If reads are being used as input
(default = False)
mash_exec (str)
Location of mash executable
Expand All @@ -342,7 +328,7 @@ def constructDatabase(assemblyList, klist, sketch_size, oPrefix,
raise NotImplementedError("Cannot use reads with mash backend")

names, sequences = readRfile(assemblyList, oneSeq=True)
genome_length, max_prob = assembly_qc(sequences, klist, ignoreLengthOutliers, estimated_length)
genome_length, max_prob = assembly_qc(sequences, klist)

# create kmer databases
if threads > len(klist):
Expand All @@ -363,6 +349,9 @@ def constructDatabase(assemblyList, klist, sketch_size, oPrefix,
pool.map(partial(runSketch, assemblyList=sequenceFile.name, sketch=sketch_size,
genome_length=genome_length,oPrefix=oPrefix, mash_exec=mash_exec,
overwrite=overwrite, threads=num_threads), klist)

# return sequence names
return sequences

def init_lock(l):
"""Sets a global lock to use when writing to STDERR in :func:`~runSketch`"""
Expand Down
32 changes: 19 additions & 13 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def networkSummary(G):

return(components, density, transitivity, score)

def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
assignments, model, queryDB, queryQuery = False,
use_mash = False, threads = 1):
"""Finds edges between queries and items in the reference database,
Expand All @@ -326,14 +326,14 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
List of backend functions from :func:`~PopPUNK.utils.setupDBFuncs`
rlist (list)
List of reference names
qfile (str)
File containing queries
qList (list)
List of query names
qFile (list)
File of query sequences
G (graph)
Network to add to (mutated)
kmers (list)
List of k-mer sizes
estimated_length (int)
Estimated length of genome, if not calculated from data
assignments (numpy.array)
Cluster assignment of items in qlist
model (ClusterModel)
Expand All @@ -342,13 +342,11 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
Query database location
queryQuery (bool)
Add in all query-query distances
(default = False)
use_mash (bool)
Use the mash backend
no_stream (bool)
Don't stream mash output
(default = False)
threads (int)
Number of threads to use if new db created
Expand All @@ -375,16 +373,24 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
distMat = None

# Set up query names
qList, qSeqs = readRfile(qfile, oneSeq = use_mash)
if use_mash == True:
# mash must use sequence file names for both testing for
# assignment and for generating a new database
rNames = None
qNames = isolateNameToLabel(qSeqs)
qNames = qList
else:
rNames = qList
qNames = rNames
queryFiles = dict(zip(qNames, qSeqs))

# identify query sequence files
qSeqs = []
queryFiles = {}
with open(qFile, 'r') as qfile:
for line in qfile.readlines():
info = line.rstrip().split()
if info[0] in qNames:
qSeqs.append(info[1])
queryFiles[info[0]] = info[1]

# store links for each query in a list of edge tuples
ref_count = len(rlist)
Expand All @@ -399,9 +405,8 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
sys.stderr.write("Calculating all query-query distances\n")
qlist1, distMat = calculateQueryQueryDistances(dbFuncs,
rNames,
qfile,
qNames,
kmers,
estimated_length,
queryDB,
use_mash,
threads)
Expand All @@ -413,6 +418,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,

# Otherwise only calculate query-query distances for new clusters
else:

# identify potentially new lineages in list: unassigned is a list of queries with no hits
unassigned = set(qSeqs).difference(assigned)
query_indices = {k:v+ref_count for v,k in enumerate(qSeqs)}
Expand All @@ -435,7 +441,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,

# use database construction methods to find links between unassigned queries
sketchSize = readDBParams(queryDB, kmers, None)[1]
constructDatabase(tmpFile, kmers, sketchSize, tmpDirName, estimated_length, True, threads, False)
constructDatabase(tmpFile, kmers, sketchSize, tmpDirName, True, threads, False)

qlist1, qlist2, distMat = queryDatabase(rNames = list(unassigned),
qNames = list(unassigned),
Expand Down

0 comments on commit aa27558

Please sign in to comment.