Skip to content

Commit

Permalink
Web (#124)
Browse files Browse the repository at this point in the history
Add web API
  • Loading branch information
Danderson123 committed Nov 25, 2020
1 parent ea9e4cb commit 634394c
Show file tree
Hide file tree
Showing 9 changed files with 426 additions and 77 deletions.
2 changes: 1 addition & 1 deletion PopPUNK/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def main():
genomeNetwork = indivNetworks[min(rank_list)]

# Ensure all in dists are in final network
networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
networkMissing = set(map(str,set(range(len(refList))).difference(list(genomeNetwork.vertices()))))
if len(networkMissing) > 0:
missing_isolates = [refList[m] for m in networkMissing]
sys.stderr.write("WARNING: Samples " + ", ".join(missing_isolates) + " are missing from the final network\n")
Expand Down
42 changes: 24 additions & 18 deletions PopPUNK/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ def assign_query(dbFuncs,
previous_clustering,
external_clustering,
core_only,
accessory_only):
"""Code for assign query mode. Written as a separate function so it can
be called by web APIs
"""
accessory_only,
web,
json_sketch):
"""Code for assign query mode. Written as a separate function so it can be called
by web APIs"""

# Modules imported here as graph tool is very slow to load (it pulls in all of GTK?)
from .models import loadClusterFit, ClusterFit, BGMMFit, DBSCANFit, RefineFit, LineageFit

Expand All @@ -66,6 +68,8 @@ def assign_query(dbFuncs,
from .utils import update_distance_matrices
from .utils import createOverallLineage

from .web import sketch_to_hdf5

createDatabaseDir = dbFuncs['createDatabaseDir']
constructDatabase = dbFuncs['constructDatabase']
joinDBs = dbFuncs['joinDBs']
Expand Down Expand Up @@ -111,18 +115,20 @@ def assign_query(dbFuncs,
rNames.append(reference.rstrip())
else:
rNames = getSeqsInDb(ref_db + "/" + os.path.basename(ref_db) + ".h5")

# construct database
createDatabaseDir(output, kmers)
qNames = constructDatabase(q_files,
kmers,
sketch_sizes,
output,
threads,
overwrite,
codon_phased = codon_phased,
calc_random = False)

if (web and json_sketch):
qNames = sketch_to_hdf5(json_sketch, output)
else:
# construct database
createDatabaseDir(output, kmers)
qNames = constructDatabase(q_files,
kmers,
sketch_sizes,
output,
threads,
overwrite,
codon_phased = codon_phased,
calc_random = False)
# run query
refList, queryList, qrDistMat = queryDatabase(rNames = rNames,
qNames = qNames,
Expand All @@ -132,7 +138,6 @@ def assign_query(dbFuncs,
self = False,
number_plot_fits = plot_fit,
threads = threads)

# QC distance matrix
qcPass = qcDistMat(qrDistMat, refList, queryList, max_a_dist)

Expand Down Expand Up @@ -348,7 +353,6 @@ def get_options():
other.add_argument('--gpu-sketch', default=False, action='store_true', help='Use a GPU when calculating sketches (read data only) [default = False]')
other.add_argument('--gpu-dist', default=False, action='store_true', help='Use a GPU when calculating distances [default = False]')
other.add_argument('--deviceid', default=0, type=int, help='CUDA device ID, if using GPU [default = 0]')

other.add_argument('--version', action='version',
version='%(prog)s '+__version__)

Expand Down Expand Up @@ -420,7 +424,9 @@ def main():
args.previous_clustering,
args.external_clustering,
args.core_only,
args.accessory_only)
args.accessory_only,
web = False,
json_sketch = None)

sys.stderr.write("\nDone\n")

Expand Down
2 changes: 1 addition & 1 deletion PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def fetchNetwork(network_dir, model, refList, ref_graph = False,
sys.stderr.write("Network loaded: " + str(len(list(genomeNetwork.vertices()))) + " samples\n")

# Ensure all in dists are in final network
networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
networkMissing = set(map(str,set(range(len(refList))).difference(list(genomeNetwork.vertices()))))
if len(networkMissing) > 0:
sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")

Expand Down
3 changes: 2 additions & 1 deletion PopPUNK/sketchlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def joinDBs(db1, db2, output):
# Can only copy into new group, so for second file these are appended one at a time
try:
hdf1.copy('sketches', hdf_join)
hdf1.copy('random', hdf_join)
if 'random' in hdf1:
hdf1.copy('random', hdf_join)
join_grp = hdf_join['sketches']
read_grp = hdf2['sketches']
for dataset in read_grp:
Expand Down
168 changes: 113 additions & 55 deletions PopPUNK/visualise.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ def get_options():
default=False,
action='store_true')

# query options
queryingGroup = parser.add_argument_group('Database querying options')
queryingGroup.add_argument('--core-only', help='(with a \'refine\' model) '
'Use a core-distance only model for assigning queries '
'[default = False]', default=False, action='store_true')
queryingGroup.add_argument('--accessory-only', help='(with a \'refine\' or \'lineage\' model) '
'Use an accessory-distance only model for assigning queries '
'[default = False]', default=False, action='store_true')

# plot output
faGroup = parser.add_argument_group('Visualisation options')
faGroup.add_argument('--microreact', help='Generate output files for microreact visualisation', default=False, action='store_true')
Expand Down Expand Up @@ -128,35 +137,54 @@ def get_options():

return args

def main():
"""Main function. Parses cmd line args and runs in the specified mode.
"""
args = get_options()
def generate_visualisations(query_db,
ref_db,
distances,
threads,
output,
gpu_dist,
deviceid,
external_clustering,
microreact,
phandango,
grapetree,
cytoscape,
perplexity,
strand_preserved,
include_files,
model_dir,
previous_clustering,
previous_query_clustering,
info_csv,
rapidnj,
overwrite,
core_only,
accessory_only):

# Check on parallelisation of graph-tools
setGtThreads(args.threads)
setGtThreads(threads)

sys.stderr.write("PopPUNK: visualise\n")
if not (args.microreact or args.phandango or args.grapetree or args.cytoscape):
if not (microreact or phandango or grapetree or cytoscape):
sys.stderr.write("Must specify at least one type of visualisation to output\n")
sys.exit(1)

# make directory for new output files
if not os.path.isdir(args.output):
if not os.path.isdir(output):
try:
os.makedirs(args.output)
os.makedirs(output)
except OSError:
sys.stderr.write("Cannot create output directory\n")
sys.exit(1)

# Load original distances
if args.distances is None:
if args.query_db is None:
distances = os.path.basename(args.ref_db) + "/" + args.ref_db + ".dists"
if distances is None:
if query_db is None:
distances = os.path.basename(ref_db) + "/" + ref_db + ".dists"
else:
distances = os.path.basename(args.query_db) + "/" + args.query_db + ".dists"
distances = os.path.basename(query_db) + "/" + query_db + ".dists"
else:
distances = args.distances
distances = distances

rlist, qlist, self, complete_distMat = readPickle(distances)
if not self:
Expand All @@ -169,33 +197,31 @@ def main():
sys.stderr.write("Note: Distances in " + distances + " are from assign mode\n"
"Note: Distance will be extended to full all-vs-all distances\n"
"Note: Re-run poppunk_assign with --update-db to avoid this\n")

ref_db = os.path.basename(args.ref_db) + "/" + args.ref_db
query_db = os.path.basename(args.query_db) + "/" + args.query_db
ref_db = os.path.basename(ref_db) + "/" + ref_db
rlist_original, qlist_original, self_ref, rr_distMat = readPickle(ref_db + ".dists")
if not self_ref:
sys.stderr.write("Distances in " + ref_db + " not self all-vs-all either\n")
sys.exit(1)

kmers, sketch_sizes, codon_phased = readDBParams(args.query_db)
addRandom(args.query_db, qlist, kmers,
strand_preserved = args.strand_preserved, threads = args.threads)
kmers, sketch_sizes, codon_phased = readDBParams(query_db)
addRandom(query_db, qlist, kmers,
strand_preserved = strand_preserved, threads = threads)
query_db = os.path.basename(query_db) + "/" + query_db
qq_distMat = pp_sketchlib.queryDatabase(query_db, query_db,
qlist, qlist, kmers,
True, False,
args.threads,
args.gpu_dist,
args.deviceid)
threads,
gpu_dist,
deviceid)

# If the assignment was run with references, qrDistMat will be incomplete
if rlist != rlist_original:
rlist = rlist_original
qr_distMat = pp_sketchlib.queryDatabase(ref_db, query_db,
rlist, qlist, kmers,
True, False,
args.threads,
args.gpu_dist,
args.deviceid)
threads,
gpu_dist,
deviceid)

else:
qlist = None
Expand All @@ -206,12 +232,12 @@ def main():
combined_seq, core_distMat, acc_distMat = \
update_distance_matrices(rlist, rr_distMat,
qlist, qr_distMat, qq_distMat,
threads = args.threads)
threads = threads)

# extract subset of distances if requested
if args.include_files is not None:
if include_files is not None:
viz_subset = set()
with open(args.include_files, 'r') as assemblyFiles:
with open(include_files, 'r') as assemblyFiles:
for assembly in assemblyFiles:
viz_subset.add(assembly.rstrip())
if len(viz_subset.difference(combined_seq)) > 0:
Expand All @@ -224,26 +250,27 @@ def main():
qlist = list(viz_subset.intersection(qlist))
core_distMat = core_distMat[np.ix_(row_slice, row_slice)]
acc_distMat = acc_distMat[np.ix_(row_slice, row_slice)]

else:
viz_subset = None
# Either use strain definitions, lineage assignments or external clustering
isolateClustering = {}
# Use external clustering if specified
if args.external_clustering:
cluster_file = args.external_clustering
if external_clustering:
cluster_file = external_clustering
isolateClustering = readIsolateTypeFromCsv(cluster_file,
mode = 'external',
return_dict = True)

# identify existing model and cluster files
if args.model_dir is not None:
model_prefix = args.model_dir
if model_dir is not None:
model_prefix = model_dir
else:
model_prefix = args.ref_db
model_file = model_prefix + "/" + os.path.basename(model_prefix)
model_prefix = ref_db
try:
model_file = os.path.basename(model_prefix) + "/" + os.path.basename(model_prefix)
model = loadClusterFit(model_file + '_fit.pkl',
model_file + '_fit.npz')
except:
except FileNotFoundError:
sys.stderr.write('Unable to locate previous model fit in ' + model_prefix + '\n')
sys.exit(1)

Expand All @@ -258,8 +285,8 @@ def main():
"visualisation only supports combined boundary fit\n")

# Set directories of previous fit
if args.previous_clustering is not None:
prev_clustering = args.previous_clustering
if previous_clustering is not None:
prev_clustering = previous_clustering
else:
prev_clustering = os.path.dirname(model_file)
cluster_file = prev_clustering + '/' + os.path.basename(prev_clustering) + suffix
Expand All @@ -268,37 +295,68 @@ def main():
return_dict = True)
# Join clusters with query clusters if required
if not self:
if args.previous_query_clustering is not None:
prev_query_clustering = args.previous_query_clustering
if previous_query_clustering is not None:
prev_query_clustering = previous_query_clustering + '/' + os.path.basename(previous_query_clustering)
else:
prev_query_clustering = args.query_db
prev_query_clustering = query_db

queryIsolateClustering = readIsolateTypeFromCsv(
prev_query_clustering + '/' + os.path.basename(prev_query_clustering) + suffix,
prev_query_clustering + suffix,
mode = mode,
return_dict = True)
isolateClustering = joinClusterDicts(isolateClustering, queryIsolateClustering)

# Now have all the objects needed to generate selected visualisations
if args.microreact:
if microreact:
sys.stderr.write("Writing microreact output\n")
outputsForMicroreact(combined_seq, core_distMat, acc_distMat, isolateClustering, args.perplexity,
args.output, args.info_csv, args.rapidnj, queryList = qlist, overwrite = args.overwrite)
if args.phandango:
outputsForMicroreact(combined_seq, core_distMat, acc_distMat, isolateClustering, perplexity,
output, info_csv, rapidnj, queryList = qlist, overwrite = overwrite)
if phandango:
sys.stderr.write("Writing phandango output\n")
outputsForPhandango(combined_seq, core_distMat, isolateClustering, args.output, args.info_csv, args.rapidnj,
queryList = qlist, overwrite = args.overwrite, microreact = args.microreact)
if args.grapetree:
outputsForPhandango(combined_seq, core_distMat, isolateClustering, output, info_csv, rapidnj,
queryList = qlist, overwrite = overwrite, microreact = microreact)
if grapetree:
sys.stderr.write("Writing grapetree output\n")
outputsForGrapetree(combined_seq, core_distMat, isolateClustering, args.output, args.info_csv, args.rapidnj,
queryList = qlist, overwrite = args.overwrite, microreact = args.microreact)
if args.cytoscape:
outputsForGrapetree(combined_seq, core_distMat, isolateClustering, output, info_csv, rapidnj,
queryList = qlist, overwrite = overwrite, microreact = microreact)
if cytoscape:
sys.stderr.write("Writing cytoscape output\n")
genomeNetwork, cluster_file = fetchNetwork(prev_clustering, model, rlist, False, args.core_only, args.accessory_only)
outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv, viz_subset = viz_subset)
genomeNetwork, cluster_file = fetchNetwork(prev_clustering, model, rlist, False, core_only, accessory_only)
outputsForCytoscape(genomeNetwork, isolateClustering, output, info_csv, viz_subset = viz_subset)
if model.type == 'lineage':
sys.stderr.write("Note: Only support for output of cytoscape graph at lowest rank\n")

sys.stderr.write("\nDone\n")

def main():
"""Main function. Parses cmd line args and runs in the specified mode.
"""
args = get_options()

generate_visualisations(args.query_db,
args.ref_db,
args.distances,
args.threads,
args.output,
args.gpu_dist,
args.deviceid,
args.external_clustering,
args.microreact,
args.phandango,
args.grapetree,
args.cytoscape,
args.perplexity,
args.strand_preserved,
args.include_files,
args.model_dir,
args.previous_clustering,
args.previous_query_clustering,
args.info_csv,
args.rapidnj,
args.overwrite,
args.core_only,
args.accessory_only)

if __name__ == '__main__':
main()

Expand Down

0 comments on commit 634394c

Please sign in to comment.