Skip to content

Commit

Permalink
Merge pull request #123 from johnlees/graph_weights
Browse files Browse the repository at this point in the history
Add option to save distances as edge weights in graph
  • Loading branch information
johnlees committed Nov 18, 2020
2 parents 2032f79 + 342db00 commit 776f077
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 22 deletions.
2 changes: 1 addition & 1 deletion PopPUNK/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
# Minimum sketchlib version
SKETCHLIB_MAJOR = 1
SKETCHLIB_MINOR = 5
SKETCHLIB_PATCH = 3
SKETCHLIB_PATCH = 3
15 changes: 13 additions & 2 deletions PopPUNK/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def get_options():
oGroup.add_argument('--plot-fit', help='Create this many plots of some fits relating k-mer to core/accessory distances '
'[default = 0]', default=0, type=int)
oGroup.add_argument('--overwrite', help='Overwrite any existing database files', default=False, action='store_true')
oGroup.add_argument('--graph-weights', help='Save within-strain Euclidean distances into the graph', default=False, action='store_true')

# comparison metrics
kmerGroup = parser.add_argument_group('Create DB options')
Expand Down Expand Up @@ -385,23 +386,33 @@ def main():
#* *#
#******************************#
if model.type != "lineage":
if args.graph_weights:
weights = distMat
else:
weights = None
genomeNetwork = \
constructNetwork(refList,
queryList,
assignments,
model.within_label)
model.within_label,
weights=weights)
else:
# Lineage fit requires some iteration
indivNetworks = {}
lineage_clusters = defaultdict(dict)
for rank in sorted(rank_list):
sys.stderr.write("Network for rank " + str(rank) + "\n")
if args.graph_weights:
weights = model.edge_weights(rank)
else:
weights = None
indivNetworks[rank] = constructNetwork(
refList,
refList,
assignments[rank],
0,
edge_list = True
edge_list=True,
weights=weights
)
lineage_clusters[rank] = \
printClusters(indivNetworks[rank],
Expand Down
17 changes: 15 additions & 2 deletions PopPUNK/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def assign_query(dbFuncs,
threads,
overwrite,
plot_fit,
graph_weights,
max_a_dist,
model_dir,
strand_preserved,
Expand Down Expand Up @@ -162,11 +163,16 @@ def assign_query(dbFuncs,
for rank in model.ranks:
assignment = model.assign(rank)
# Overwrite the network loaded above
if graph_weights:
weights = model.edge_weights(rank)
else:
weights = None
genomeNetwork[rank] = constructNetwork(rNames + qNames,
rNames + qNames,
assignment,
0,
edge_list = True)
edge_list = True,
weights=weights)

isolateClustering[rank] = \
printClusters(genomeNetwork[rank],
Expand All @@ -189,11 +195,16 @@ def assign_query(dbFuncs,
queryAssignments = model.assign(qrDistMat)

# Assign clustering by adding to network
if graph_weights:
weights = qrDistMat
else:
weights = None
qqDistMat = \
addQueryToNetwork(dbFuncs, refList, queryList,
genomeNetwork, kmers,
queryAssignments, model, output, update_db,
strand_preserved, threads)
strand_preserved,
weights = weights, threads = threads)

isolateClustering = \
{'combined': printClusters(genomeNetwork, refList + queryList,
Expand Down Expand Up @@ -301,6 +312,7 @@ def get_options():
default=False, action='store_true')
oGroup.add_argument('--update-db', help='Update reference database with query sequences', default=False, action='store_true')
oGroup.add_argument('--overwrite', help='Overwrite any existing database files', default=False, action='store_true')
oGroup.add_argument('--graph-weights', help='Save within-strain Euclidean distances into the graph', default=False, action='store_true')

# comparison metrics
kmerGroup = parser.add_argument_group('Kmer comparison options')
Expand Down Expand Up @@ -397,6 +409,7 @@ def main():
args.threads,
args.overwrite,
args.plot_fit,
args.graph_weights,
args.max_a_dist,
args.model_dir,
args.strand_preserved,
Expand Down
15 changes: 15 additions & 0 deletions PopPUNK/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,21 @@ def assign(self, rank):

return y

def edge_weights(self, rank):
    '''Return the stored distances backing the edges produced by assign

    Args:
        rank (int)
            Rank assigned at
    Returns:
        weights
            The ``data`` attribute of ``self.nn_dists[rank]``, i.e. one
            distance per assignment (presumably the flat value array of a
            sparse matrix -- confirm against where nn_dists is built)
    Raises:
        RuntimeError
            If the model has not been fitted yet
    '''
    if self.fitted:
        rank_dists = self.nn_dists[rank]
        return rank_dists.data
    raise RuntimeError("Trying to get weights from an unfitted model")

def extend(self, qqDists, qrDists):
# Reshape qq and qr dist matrices
qqSquare = pp_sketchlib.longToSquare(qqDists[:, [self.dist_col]], 1)
Expand Down
75 changes: 60 additions & 15 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def writeReferences(refList, outPrefix):
return refFileName

def constructNetwork(rlist, qlist, assignments, within_label,
summarise = True, edge_list = False):
summarise = True, edge_list = False, weights = None):
"""Construct an unweighted, undirected network without self-loops.
Nodes are samples and edges where samples are within the same cluster
Expand All @@ -283,6 +283,11 @@ def constructNetwork(rlist, qlist, assignments, within_label,
Whether to calculate and print network summaries with :func:`~networkSummary`
(default = True)
edge_list (bool)
Whether input is edges, tuples of (v1, v2). Used with lineage assignment
weights (numpy.array)
If passed, the core,accessory distances for each assignment, which will
be annotated as an edge attribute
Returns:
G (graph)
Expand All @@ -300,18 +305,34 @@ def constructNetwork(rlist, qlist, assignments, within_label,

# identify edges
if edge_list:
connections = assignments
if weights is not None:
connections = []
for weight, (ref, query) in zip(weights, assignments):
connections.append((ref, query, weight))
else:
connections = assignments
else:
for assignment, (ref, query) in zip(assignments,
listDistInts(rlist, qlist,
self = self_comparison)):
for row_idx, (assignment, (ref, query)) in enumerate(zip(assignments,
listDistInts(rlist, qlist,
self = self_comparison))):
if assignment == within_label:
connections.append((ref, query))
if weights is not None:
dist = np.linalg.norm(weights[row_idx, :])
edge_tuple = (ref, query, dist)
else:
edge_tuple = (ref, query)
connections.append(edge_tuple)

# build the graph
G = gt.Graph(directed = False)
G.add_vertex(len(vertex_labels))
G.add_edge_list(connections)

if weights is not None:
eweight = G.new_ep("float")
G.add_edge_list(connections, eprops = [eweight])
G.edge_properties["weight"] = eweight
else:
G.add_edge_list(connections)

# add isolate ID to network
vid = G.new_vertex_property('string',
Expand Down Expand Up @@ -356,7 +377,7 @@ def networkSummary(G):

def addQueryToNetwork(dbFuncs, rList, qList, G, kmers,
assignments, model, queryDB, queryQuery = False,
strand_preserved = False, threads = 1):
strand_preserved = False, weights = None, threads = 1):
"""Finds edges between queries and items in the reference database,
and modifies the network to include them.
Expand Down Expand Up @@ -384,6 +405,9 @@ def addQueryToNetwork(dbFuncs, rList, qList, G, kmers,
Whether to treat strand as known (i.e. ignore rc k-mers)
when adding random distances. Only used if queryQuery = True
[default = False]
weights (numpy.array)
If passed, the core,accessory distances for each assignment, which will
be annotated as an edge attribute
threads (int)
Number of threads to use if new db created
Expand All @@ -404,10 +428,15 @@ def addQueryToNetwork(dbFuncs, rList, qList, G, kmers,

# store links for each query in a list of edge tuples
ref_count = len(rList)
for assignment, (ref, query) in zip(assignments, listDistInts(rList, qList, self = False)):
for row_idx, (assignment, (ref, query)) in enumerate(zip(assignments, listDistInts(rList, qList, self = False))):
if assignment == model.within_label:
# query index needs to be adjusted for existing vertices in network
new_edges.append((ref, query + ref_count))
if weights is not None:
dist = np.linalg.norm(weights[row_idx, :])
edge_tuple = (ref, query + ref_count, dist)
else:
edge_tuple = (ref, query + ref_count)
new_edges.append(edge_tuple)
assigned.add(qList[query])

# Calculate all query-query distances too, if updating database
Expand All @@ -424,9 +453,14 @@ def addQueryToNetwork(dbFuncs, rList, qList, G, kmers,
threads = threads)

queryAssignation = model.assign(qqDistMat)
for assignment, (ref, query) in zip(queryAssignation, listDistInts(qList, qList, self = True)):
for row_idx, (assignment, (ref, query)) in enumerate(zip(queryAssignation, listDistInts(qList, qList, self = True))):
if assignment == model.within_label:
new_edges.append((ref + ref_count, query + ref_count))
if weights is not None:
dist = np.linalg.norm(qqDistMat[row_idx, :])
edge_tuple = (ref + ref_count, query + ref_count, dist)
else:
edge_tuple = (ref + ref_count, query + ref_count)
new_edges.append(edge_tuple)

# Otherwise only calculate query-query distances for new clusters
else:
Expand All @@ -453,13 +487,24 @@ def addQueryToNetwork(dbFuncs, rList, qList, G, kmers,
# identify any links between queries and store in the same links dict
# links dict now contains lists of links both to original database and new queries
# have to use names and link to query list in order to match to node indices
for assignment, (query1, query2) in zip(queryAssignation, iterDistRows(qlist1, qlist2, self = True)):
for row_idx, (assignment, (query1, query2)) in enumerate(zip(queryAssignation, iterDistRows(qlist1, qlist2, self = True))):
if assignment == model.within_label:
new_edges.append((query_indices[query1], query_indices[query2]))
if weights is not None:
dist = np.linalg.norm(qqDistMat[row_idx, :])
edge_tuple = (query_indices[query1], query_indices[query2], dist)
else:
edge_tuple = (query_indices[query1], query_indices[query2])
new_edges.append(edge_tuple)

# finish by updating the network
G.add_vertex(len(qList))
G.add_edge_list(new_edges)

if weights is not None:
eweight = G.new_ep("float")
G.add_edge_list(new_edges, eprops = [eweight])
G.edge_properties["weight"] = eweight
else:
G.add_edge_list(new_edges)

# including the vertex ID property map
for i, q in enumerate(qList):
Expand Down
82 changes: 82 additions & 0 deletions scripts/poppunk_add_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python
# vim: set fileencoding=<utf-8> :
# Copyright 2018-2020 John Lees and Nick Croucher

import sys
import argparse
import pickle
import numpy as np

# command line parsing
def get_options():
    """Parse the command line of the add_weights script.

    Three positional arguments are required (graph, distances, output) and
    one optional flag (--graphml) is accepted.

    Returns:
        args (argparse.Namespace)
            Parsed options, read from sys.argv
    """
    cli = argparse.ArgumentParser(prog='add_weights',
                                  description='Add edge weights to a PopPUNK graph')

    # input options: (name, help text), registered in positional order
    positionals = (
        ('graph', 'Input graph (.gt)'),
        ('distances', 'Prefix for distances (<name>.dists)'),
        ('output', 'Prefix for output graph'),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, help=help_text)

    cli.add_argument('--graphml', action='store_true', default=False,
                     help='Save output as graphml file')

    options = cli.parse_args()
    return options

def quit_msg(message):
    """Print message (followed by a newline) to stderr, then exit with status 1."""
    print(message, file=sys.stderr)
    raise SystemExit(1)

# Convert an i, j square index to long form (see sketchlib for info)
def square_to_condensed(i, j, n):
    """Return the long-form (condensed) row index of the pair (i, j).

    The long form stores each unordered pair of an n x n symmetric matrix
    once, row by row over the upper triangle, so (i, j) and (j, i) share the
    same row and the diagonal has no entry.

    Args:
        i (int)
            First sample index
        j (int)
            Second sample index (may be smaller or larger than i)
        n (int)
            Number of samples (side length of the square matrix)
    Returns:
        row_idx (int)
            Row of the pair in the long-form matrix
    Raises:
        ValueError
            If i == j, as the diagonal is not stored
    """
    # Explicit check rather than assert, which is stripped under `python -O`
    if i == j:
        raise ValueError("No long-form entry for a diagonal element (i == j)")
    # The formula below is only valid for the upper triangle (i < j)
    if i > j:
        i, j = j, i
    return n * i - ((i * (i + 1)) >> 1) + j - 1 - i

# main code
if __name__ == "__main__":

    # Get command line options
    args = get_options()

    # Load network (imported here so that --help works without graph-tool)
    import graph_tool.all as gt
    G = gt.load_graph(args.graph)
    if "weight" in G.edge_properties:
        quit_msg("Graph already contains weights")

    # Load dists
    # NOTE: pickle.load can execute arbitrary code; only use distance files
    # from a trusted source
    with open(args.distances + ".pkl", 'rb') as pickle_file:
        rlist, qlist, self_comparison = pickle.load(pickle_file)
    if not self_comparison:
        quit_msg("Distances are from query mode")
    dist_mat = np.load(args.distances + ".npy")

    # Check network and dists are compatible
    network_order = list(G.vertex_properties["id"])
    if len(network_order) != len(rlist) or set(network_order) != set(rlist):
        quit_msg("Names in distances do not match those in graph")
    n = G.num_vertices()
    # Explicit integer check rather than a float assert, which would also be
    # skipped under `python -O`
    if n * (n - 1) // 2 != dist_mat.shape[0]:
        quit_msg("Number of distances does not match number of samples in graph")

    # Match dist row order with network order:
    # v_idx[graph vertex index] -> position of that sample in rlist, which is
    # the index the long-form distance rows are laid out by
    dist_position = {name: idx for idx, name in enumerate(rlist)}
    v_idx = [dist_position[name] for name in network_order]

    eprop = G.new_edge_property("float")
    for edge in G.edges():
        # Order the pair AFTER mapping to distance indices, as the mapping
        # need not preserve the vertex order and the long form needs i < j
        i, j = sorted(v_idx[int(v)] for v in edge)
        row_idx = square_to_condensed(i, j, n)
        # Euclidean norm of the distance row for this pair
        eprop[edge] = np.linalg.norm(dist_mat[row_idx, :])

    # Add as edge attribute
    G.edge_properties["weight"] = eprop
    if args.graphml:
        G.save(args.output + ".graphml", fmt="graphml")
    else:
        G.save(args.output + ".gt")

    sys.exit(0)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def find_version(*file_paths):
'scripts/poppunk_extract_components.py',
'scripts/poppunk_calculate_silhouette.py',
'scripts/poppunk_extract_distances.py',
'scripts/poppunk_add_weights.py',
'scripts/poppunk_pickle_fix.py'],
install_requires=['numpy',
'scipy',
Expand Down
4 changes: 2 additions & 2 deletions test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

#fit DBSCAN
sys.stderr.write("Running DBSCAN model fit (--fit-model dbscan)\n")
subprocess.run("python ../poppunk-runner.py --fit-model dbscan --distances example_db/example_db.dists --ref-db example_db --output example_dbscan --overwrite", shell=True, check=True)
subprocess.run("python ../poppunk-runner.py --fit-model dbscan --distances example_db/example_db.dists --ref-db example_db --output example_dbscan --overwrite --graph-weights", shell=True, check=True)

#refine model with GMM
sys.stderr.write("Running model refinement (--fit-model refine)\n")
Expand All @@ -45,7 +45,7 @@
#assign query
sys.stderr.write("Running query assignment\n")
subprocess.run("python ../poppunk_assign-runner.py --q-files some_queries.txt --distances example_db/example_db.dists --ref-db example_db --output example_query --overwrite", shell=True, check=True)
subprocess.run("python ../poppunk_assign-runner.py --q-files some_queries.txt --distances example_db/example_db.dists --ref-db example_db --output example_query_update --update-db --overwrite", shell=True, check=True)
subprocess.run("python ../poppunk_assign-runner.py --q-files some_queries.txt --distances example_db/example_db.dists --ref-db example_db --output example_query_update --update-db --graph-weights --overwrite", shell=True, check=True)
subprocess.run("python ../poppunk_assign-runner.py --q-files single_query.txt --distances example_db/example_db.dists --ref-db example_db --output example_single_query --update-db --overwrite", shell=True, check=True)
subprocess.run("python ../poppunk_assign-runner.py --q-files some_queries.txt --distances example_db/example_db.dists --ref-db example_db --model-dir example_lineages --output example_lineage_query --overwrite", shell=True, check=True)

Expand Down

0 comments on commit 776f077

Please sign in to comment.