Merge pull request #108 from johnlees/lineage_sketchlib
Use pp-sketchlib for lineage assignment code
johnlees committed Nov 11, 2020
2 parents 90e558c + ebdf5b0 commit b3c2810
Showing 13 changed files with 591 additions and 808 deletions.
2 changes: 1 addition & 1 deletion PopPUNK/__init__.py
@@ -3,4 +3,4 @@

'''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''

__version__ = '2.2.1'
__version__ = '2.3.0'
503 changes: 257 additions & 246 deletions PopPUNK/__main__.py

Large diffs are not rendered by default.

427 changes: 0 additions & 427 deletions PopPUNK/lineage_clustering.py

This file was deleted.

223 changes: 218 additions & 5 deletions PopPUNK/models.py
@@ -12,10 +12,12 @@
import operator
import pickle
import shutil
import re
from sklearn import utils
import scipy.optimize
from scipy.spatial.distance import euclidean
from scipy import stats
from scipy.sparse import coo_matrix, bmat, find

import pp_sketchlib

@@ -41,7 +43,15 @@
from .refine import readManualStart
from .plot import plot_refined_results

def loadClusterFit(pkl_file, npz_file, outPrefix = "", max_samples=100000):
# lineage
from .plot import distHistogram
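# Clamp value for distances: exact zeroes would be stored implicitly
# (i.e. dropped) by the scipy sparse matrices, so are raised to epsilon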
epsilon = 1e-10

# Format for rank fits
def rankFile(rank):
return('_rank' + str(rank) + '_fit.npz')
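# e.g. rankFile(2) -> '_rank2_fit.npz'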

def loadClusterFit(pkl_file, npz_file, outPrefix = "", max_samples = 100000):
'''Call this to load a fitted model
Args:
@@ -58,23 +68,36 @@ def loadClusterFit(pkl_file, npz_file, outPrefix = "", max_samples = 100000):
'''
with open(pkl_file, 'rb') as pickle_obj:
fit_object, fit_type = pickle.load(pickle_obj)
fit_data = np.load(npz_file)

if fit_type == 'lineage':
# Can't save multiple sparse matrices to the same file, so do some
# file name processing
fit_data = {}
for rank in fit_object[0]:
fit_file = os.path.basename(pkl_file)
prefix = re.match(r"^(.+)_fit\.pkl$", fit_file)
rank_file = os.path.dirname(pkl_file) + "/" + \
prefix.group(1) + rankFile(rank)
fit_data[rank] = scipy.sparse.load_npz(rank_file)
else:
fit_data = np.load(npz_file)

if fit_type == "bgmm":
sys.stderr.write("Loading BGMM 2D Gaussian model\n")
load_obj = BGMMFit(outPrefix, max_samples)
load_obj.load(fit_data, fit_object)
elif fit_type == "dbscan":
sys.stderr.write("Loading DBSCAN model\n")
load_obj = DBSCANFit(outPrefix, max_samples)
load_obj.load(fit_data, fit_object)
elif fit_type == "refine":
sys.stderr.write("Loading previously refined model\n")
load_obj = RefineFit(outPrefix)
load_obj.load(fit_data, fit_object)
elif fit_type == "lineage":
sys.stderr.write("Loading previously lineage cluster model\n")
load_obj = LineageFit(outPrefix, fit_object[0])
else:
raise RuntimeError("Undefined model type: " + str(fit_type))

load_obj.load(fit_data, fit_object)
return load_obj

class ClusterFit:
@@ -759,3 +782,193 @@ def assign(self, X, slope=None, cpus=1):
return y


class LineageFit(ClusterFit):
'''Class for fits using the lineage assignment model. Inherits from :class:`ClusterFit`.
Must first run either :func:`~LineageFit.fit` or :func:`~LineageFit.load` before calling
other functions
Args:
outPrefix (str)
The output prefix used for reading/writing
ranks (list)
The ranks used in the fit
'''

def __init__(self, outPrefix, ranks):
ClusterFit.__init__(self, outPrefix)
self.type = 'lineage'
self.preprocess = False
self.ranks = []
for rank in sorted(ranks):
if (rank < 1):
sys.stderr.write("Rank must be at least 1\n")
sys.exit(1)
else:
self.ranks.append(int(rank))


def fit(self, X, accessory, threads):
'''Extends :func:`~ClusterFit.fit`
Gets assignments by using nearest neighbours.
Args:
X (numpy.array)
The core and accessory distances to cluster. Must be set if
preprocess is set.
accessory (bool)
Use accessory rather than core distances
threads (int)
Number of threads to use
Returns:
y (numpy.array)
Cluster assignments of samples in X
'''
ClusterFit.fit(self, X)
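# X holds one row per sample pair, i.e. m = n(n - 1)/2 condensed
# distances; inverting gives the sample count n = (1 + sqrt(1 + 8m))/2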
sample_size = int(round(0.5 * (1 + np.sqrt(1 + 8 * X.shape[0]))))
if (max(self.ranks) >= sample_size):
sys.stderr.write("Rank must be less than the number of samples")
sys.exit(0)

if accessory:
self.dist_col = 1
else:
self.dist_col = 0

self.nn_dists = {}
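# For each rank, square-form the condensed distances and sparsify,
# keeping each sample's `rank` nearest neighbours (the 0 appears to
# disable any distance cutoff, with rank acting as the kNN parameter)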
for rank in self.ranks:
row, col, data = \
pp_sketchlib.sparsifyDists(
pp_sketchlib.longToSquare(X[:, [self.dist_col]], threads),
0,
rank,
threads
)
data = [epsilon if d < epsilon else d for d in data]
self.nn_dists[rank] = coo_matrix((data, (row, col)),
shape=(sample_size, sample_size),
dtype = X.dtype)

self.fitted = True

y = self.assign(min(self.ranks))
return y

def save(self):
'''Save the model to disk, as an npz and pkl (using outPrefix).'''
if not self.fitted:
raise RuntimeError("Trying to save unfitted model")
else:
for rank in self.ranks:
scipy.sparse.save_npz(
self.outPrefix + "/" + os.path.basename(self.outPrefix) + \
rankFile(rank),
self.nn_dists[rank])
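# Per-rank sparse matrices are saved individually above, as several
# sparse matrices cannot share one npz; the pickle below records only
# [[ranks, dist_col], model type]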
with open(self.outPrefix + "/" + os.path.basename(self.outPrefix) + \
'_fit.pkl', 'wb') as pickle_file:
pickle.dump([[self.ranks, self.dist_col], self.type], pickle_file)

def load(self, fit_npz, fit_obj):
'''Load the model from disk. Called from :func:`~loadClusterFit`
Args:
fit_npz (dict)
Fit npz opened with :func:`numpy.load`
fit_obj (list)
The saved fit object (the fitted ranks and distance column index)
'''
self.ranks, self.dist_col = fit_obj
self.nn_dists = fit_npz
self.fitted = True

def plot(self, X):
'''Extends :func:`~ClusterFit.plot`
Write a summary of the fit, and plot histograms of the nearest
neighbour distances using :func:`PopPUNK.plot.distHistogram`
Args:
X (numpy.array)
Core and accessory distances
'''
ClusterFit.plot(self, X)
for rank in self.ranks:
distHistogram(self.nn_dists[rank].data,
rank,
self.outPrefix + "/" + os.path.basename(self.outPrefix))

def assign(self, rank):
'''Get the edges for the network. A little different from the other
methods, as it does not go through the long-form distance vector (the
coo_matrix is already in essentially the correct graph-tool edge format)
Args:
rank (int)
Rank to assign at
Returns:
y (list of tuples)
Edges to include in network
'''
if not self.fitted:
raise RuntimeError("Trying to assign using an unfitted model")
else:
y = []
for row, col in zip(self.nn_dists[rank].row, self.nn_dists[rank].col):
y.append((row, col))

return y

def extend(self, qqDists, qrDists):
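'''Extend the sparse distance matrices with query-query (qqDists)
and query-ref (qrDists) distances, then re-sparsify each row to the
fitted ranks. Returns network edges at the lowest rank.
'''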
# Reshape qq and qr dist matrices
qqSquare = pp_sketchlib.longToSquare(qqDists[:, [self.dist_col]], 1)
qqSquare[qqSquare < epsilon] = epsilon

n_ref = self.nn_dists[self.ranks[0]].shape[0]
n_query = qqSquare.shape[1]
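# Query-ref distances arrive as a long-form vector; reshape into a
# rectangular (n_query x n_ref) block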
qrRect = qrDists[:, [self.dist_col]].reshape(n_query, n_ref)
qrRect[qrRect < epsilon] = epsilon

for rank in self.ranks:
# Add the matrices together to make a large square matrix
full_mat = bmat([[self.nn_dists[rank], qrRect.transpose()],
[qrRect, qqSquare]],
format = 'csr',
dtype = self.nn_dists[rank].dtype)

# Reapply the rank to each row, using sparse matrix functions
data = []
row = []
col = []
for row_idx in range(full_mat.shape[0]):
sample_row = full_mat.getrow(row_idx)
dist_row, dist_col, dist = find(sample_row)
dist = [epsilon if d < epsilon else d for d in dist]
dist_idx_sort = np.argsort(dist)

# Identical to C++ code in matrix_ops.cpp:sparsify_dists
neighbours = 0
prev_val = -1
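# Keep the `rank` nearest neighbours in each row, skipping self
# comparisons; runs of tied distances (new_val True, i.e. within
# epsilon of the previous value) do not count towards the rank limit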
for sort_idx in dist_idx_sort:
if row_idx == dist_col[sort_idx]:
continue
new_val = abs(dist[sort_idx] - prev_val) < epsilon
if (neighbours < rank or new_val):
data.append(dist[sort_idx])
row.append(row_idx)
col.append(dist_col[sort_idx])

if not new_val:
neighbours += 1
prev_val = data[-1]
else:
break

self.nn_dists[rank] = coo_matrix((data, (row, col)),
shape=(full_mat.shape[0], full_mat.shape[0]),
dtype = self.nn_dists[rank].dtype)

y = self.assign(min(self.ranks))
return y

65 changes: 20 additions & 45 deletions PopPUNK/network.py
@@ -205,29 +205,8 @@ def writeReferences(refList, outPrefix):

return refFileName

def writeDummyReferences(refList, outPrefix):
"""Writes chosen references to file, for use with mash
Gives sequence name twice
Args:
refList (list)
Reference names to write
outPrefix (str)
Prefix for output file (.refs will be appended)
Returns:
refFileName (str)
The name of the file references were written to
"""
# write references to file
refFileName = outPrefix + "/" + os.path.basename(outPrefix) + ".mash.refs"
with open(refFileName, 'w') as rFile:
for ref in refList:
rFile.write("\t".join([ref, ref]) + '\n')

return refFileName

def constructNetwork(rlist, qlist, assignments, within_label, summarise = True):
def constructNetwork(rlist, qlist, assignments, within_label,
summarise = True, edge_list = False):
"""Construct an unweighted, undirected network without self-loops.
Nodes are samples and edges where samples are within the same cluster
@@ -263,9 +242,14 @@ def constructNetwork(rlist, qlist, assignments, within_label, summarise = True):
vertex_labels.append(qlist)

# identify edges
for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qlist, self = self_comparison)):
if assignment == within_label:
connections.append((ref, query))
if edge_list:
connections = assignments
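# with edge_list set, assignments is already a list of (ref, query)
# edge tuples, e.g. as returned by LineageFit.assign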
else:
for assignment, (ref, query) in zip(assignments,
listDistInts(rlist, qlist,
self = self_comparison)):
if assignment == within_label:
connections.append((ref, query))

# build the graph
G = gt.Graph(directed = False)
@@ -313,21 +297,19 @@ def networkSummary(G):

return(components, density, transitivity, score)

def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
assignments, model, queryDB, queryQuery = False,
use_mash = False, threads = 1):
def addQueryToNetwork(dbFuncs, rList, qList, G, kmers,
assignments, model, queryDB, queryQuery = False,
threads = 1):
"""Finds edges between queries and items in the reference database,
and modifies the network to include them.
Args:
dbFuncs (list)
List of backend functions from :func:`~PopPUNK.utils.setupDBFuncs`
rlist (list)
rList (list)
List of reference names
qList (list)
List of query names
qFile (list)
File of query sequences
G (graph)
Network to add to (mutated)
kmers (list)
@@ -341,11 +323,6 @@ def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
queryQuery (bool)
Add in all query-query distances
(default = False)
use_mash (bool)
Use the mash backend
no_stream (bool)
Don't stream mash output
(default = False)
threads (int)
Number of threads to use if new db created
@@ -355,10 +332,7 @@
Query-query distances
"""
# initialise functions
readDBParams = dbFuncs['readDBParams']
constructDatabase = dbFuncs['constructDatabase']
queryDatabase = dbFuncs['queryDatabase']
readDBParams = dbFuncs['readDBParams']

# initialise links data structure
new_edges = []
Expand All @@ -368,8 +342,8 @@ def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
qqDistMat = None

# store links for each query in a list of edge tuples
ref_count = len(rlist)
for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qList, self = False)):
ref_count = len(rList)
for assignment, (ref, query) in zip(assignments, listDistInts(rList, qList, self = False)):
if assignment == model.within_label:
# query index needs to be adjusted for existing vertices in network
new_edges.append((ref, query + ref_count))
@@ -425,13 +399,14 @@ def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
G.add_edge_list(new_edges)

# including the vertex ID property map
for i,q in enumerate(qList):
G.vp.id[i + len(rlist)] = q
for i, q in enumerate(qList):
G.vp.id[i + len(rList)] = q

return qqDistMat

def printClusters(G, rlist, outPrefix = "_clusters.csv", oldClusterFile = None,
externalClusterCSV = None, printRef = True, printCSV = True, clustering_type = 'combined'):
externalClusterCSV = None, printRef = True, printCSV = True,
clustering_type = 'combined'):
"""Get cluster assignments
Also writes assignments to a CSV file
