Skip to content

Commit

Permalink
Merge pull request #110 from johnlees/clique_speed
Browse files Browse the repository at this point in the history
Uses new algorithm for clique pruning, which has better scaling
  • Loading branch information
johnlees committed Nov 11, 2020
2 parents b3c2810 + 501fd87 commit c519f58
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 42 deletions.
4 changes: 2 additions & 2 deletions PopPUNK/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def main():
# (this no longer loses information and should generally be kept on)
if not args.full_db:
newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = \
extractReferences(genomeNetwork, refList, args.output)
extractReferences(genomeNetwork, refList, args.output, threads = args.threads)
nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
names_to_remove = [refList[n] for n in nodes_to_remove]
# Save reference distances
Expand Down Expand Up @@ -980,7 +980,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
dbOrder = refList + queryList
newRepresentativesIndices, newRepresentativesNames, \
newRepresentativesFile, genomeNetwork = \
extractReferences(genomeNetwork, dbOrder, output, refList)
extractReferences(genomeNetwork, dbOrder, output, refList, threads = threads)
# intersection that maintains order
newQueries = [x for x in queryList if x in frozenset(newRepresentativesNames)]
genomeNetwork.save(output + "/" + os.path.basename(output) + '.refs_graph.gt', fmt = 'gt')
Expand Down
123 changes: 85 additions & 38 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from scipy.stats import rankdata
from tempfile import mkstemp, mkdtemp
from collections import defaultdict, Counter
from functools import partial
from multiprocessing import Pool

from .utils import iterDistRows
from .utils import listDistInts
Expand Down Expand Up @@ -81,8 +83,48 @@ def fetchNetwork(network_dir, model, refList,

return (genomeNetwork, cluster_file)

def getCliqueRefs(G, reference_indices = set()):
"""Recursively prune a network of its cliques. Returns one vertex from
a clique at each stage
def extractReferences(G, dbOrder, outPrefix, existingRefs = None):
Args:
G (graph)
The graph to get clique representatives from
reference_indices (set)
The unique list of vertices being kept, to add to
"""
cliques = gt.max_cliques(G)
try:
# Get the first clique, and see if it has any members already
# contained in the vertex list
clique = frozenset(next(cliques))
if clique.isdisjoint(reference_indices):
reference_indices.add(list(clique)[0])

# Remove the clique, and prune the resulting subgraph (recursively)
subgraph = gt.GraphView(G, vfilt=[v not in clique for v in G.vertices()])
if subgraph.num_vertices() > 1:
getCliqueRefs(subgraph, reference_indices)
elif subgraph.num_vertices() == 1:
reference_indices.add(subgraph.get_vertices()[0])
except StopIteration:
pass
return reference_indices

def cliquePrune(component, graph, reference_indices, components_list):
"""Wrapper function around :func:`~getCliqueRefs` so it can be
called by a multiprocessing pool
"""
subgraph = gt.GraphView(graph, vfilt=components_list == component)
refs = reference_indices.copy()
if subgraph.num_vertices() <= 2:
refs.add(subgraph.get_vertices()[0])
ref_list = refs
else:
ref_list = getCliqueRefs(subgraph, refs)
return(list(ref_list))

def extractReferences(G, dbOrder, outPrefix, existingRefs = None, threads = 1):
"""Extract references for each cluster based on cliques
Writes chosen references to file by calling :func:`~writeReferences`
Expand All @@ -105,65 +147,70 @@ def extractReferences(G, dbOrder, outPrefix, existingRefs = None):
"""
if existingRefs == None:
references = set()
reference_indices = []
reference_indices = set()
else:
references = set(existingRefs)
index_lookup = {v:k for k,v in enumerate(dbOrder)}
reference_indices = [index_lookup[r] for r in references]

# extract cliques from network
cliques_in_overall_graph = [c.tolist() for c in gt.max_cliques(G)]
# order list by size of clique
cliques_in_overall_graph.sort(key = len, reverse = True)
# iterate through cliques
for clique in cliques_in_overall_graph:
alreadyRepresented = 0
for node in clique:
if node in reference_indices:
alreadyRepresented = 1
break
if alreadyRepresented == 0:
reference_indices.append(clique[0])

# Find any clusters which are represented by multiple references
# First get cluster assignments
clusters_in_overall_graph = printClusters(G, dbOrder, printCSV=False)
# Construct a dict containing one empty set for each cluster
reference_clusters_in_overall_graph = [set() for c in set(clusters_in_overall_graph.items())]
# Iterate through references
for reference_index in reference_indices:
# Add references to the originally empty set for the appropriate cluster
# Allows enumeration of the number of references per cluster
reference_clusters_in_overall_graph[clusters_in_overall_graph[dbOrder[reference_index]]].add(reference_index)
reference_indices = set([index_lookup[r] for r in references])

# Each component is independent, so can be multithreaded
components = gt.label_components(G)[0].a

# Turn gt threading off and on again either side of the parallel loop
if gt.openmp_enabled():
gt.openmp_set_num_threads(1)

# Cliques are pruned, taking one reference from each, until none remain
with Pool(processes=threads) as pool:
ref_lists = pool.map(partial(cliquePrune,
graph=G,
reference_indices=reference_indices,
components_list=components),
set(components))
# Returns nested lists, which need to be flattened
reference_indices = set([entry for sublist in ref_lists for entry in sublist])

if gt.openmp_enabled():
gt.openmp_set_num_threads(threads)

# Use a vertex filter to extract the subgraph of refences
# as a graphview
reference_vertex = G.new_vertex_property('bool')
for n,vertex in enumerate(G.vertices()):
for n, vertex in enumerate(G.vertices()):
if n in reference_indices:
reference_vertex[vertex] = True
else:
reference_vertex[vertex] = False
G_ref = gt.GraphView(G, vfilt = reference_vertex)
G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object
# Calculate component membership for reference graph
clusters_in_reference_graph = printClusters(G, dbOrder, printCSV=False)
# Record to which components references below in the reference graph
reference_clusters_in_reference_graph = {}

# Find any clusters which are represented by >1 references
# This creates a dictionary: cluster_id: set(ref_idx in cluster)
clusters_in_full_graph = printClusters(G, dbOrder, printCSV=False)
reference_clusters_in_full_graph = defaultdict(set)
for reference_index in reference_indices:
reference_clusters_in_reference_graph[dbOrder[reference_index]] = clusters_in_reference_graph[dbOrder[reference_index]]
reference_clusters_in_full_graph[clusters_in_full_graph[dbOrder[reference_index]]].add(reference_index)

# Calculate the component membership within the reference graph
ref_order = [name for idx, name in enumerate(dbOrder) if idx in frozenset(reference_indices)]
clusters_in_reference_graph = printClusters(G_ref, ref_order, printCSV=False)
# Record the components/clusters the references are in the reference graph
# dict: name: ref_cluster
reference_clusters_in_reference_graph = {}
for reference_name in ref_order:
reference_clusters_in_reference_graph[reference_name] = clusters_in_reference_graph[reference_name]

# Check if multi-reference components have been split as a validation test
# First iterate through clusters
network_update_required = False
for cluster in reference_clusters_in_overall_graph:
for cluster_id, ref_idxs in reference_clusters_in_full_graph.items():
# Identify multi-reference clusters by this length
if len(cluster) > 1:
check = list(cluster)
if len(ref_idxs) > 1:
check = list(ref_idxs)
# check if these are still in the same component in the reference graph
for i in range(len(check)):
component_i = reference_clusters_in_reference_graph[dbOrder[check[i]]]
for j in range(i, len(check)):
for j in range(i + 1, len(check)):
# Add intermediate nodes
component_j = reference_clusters_in_reference_graph[dbOrder[check[j]]]
if component_i != component_j:
Expand Down
13 changes: 11 additions & 2 deletions PopPUNK/reference_pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,27 @@ def main():
# This is the same set of function calls for --fit-model when no --full-db in __main__.py
# Find refs and prune network
reference_indices, reference_names, refFileName, G_ref = \
extractReferences(genomeNetwork, refList, args.output)
extractReferences(genomeNetwork, refList, args.output, threads = args.threads)
G_ref.save(args.output + "/" + os.path.basename(args.output) + '_graph.gt', fmt = 'gt')

# Prune distances
nodes_to_remove = set(range(len(refList))).difference(reference_indices)
names_to_remove = [refList[n] for n in nodes_to_remove]
prune_distance_matrix(refList, nodes_to_remove, distMat,
prune_distance_matrix(refList, names_to_remove, distMat,
args.output + "/" + os.path.basename(args.output) + ".dists")

# 'Resketch'
if len(nodes_to_remove) > 0:
removeFromDB(args.ref_db, args.output, set(refList) - set(reference_names))

db_outfile = args.output + "/" + os.path.basename(args.output) + ".tmp.h5"
db_infile = args.output + "/" + os.path.basename(args.output) + ".h5"
if os.path.exists(db_infile):
sys.stderr.write("Sketch DB exists in " + args.output + "\n"
"Not overwriting. Output DB is: " +
db_outfile + "\n")
else:
os.rename(db_outfile, db_infile)
else:
sys.stderr.write("No sequences to remove\n")

Expand Down

0 comments on commit c519f58

Please sign in to comment.