Deal with paths in output correctly
Closes #25
johnlees committed Jun 26, 2018
1 parent f0a98ac commit 506e972
Showing 9 changed files with 72 additions and 65 deletions.
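The bug (#25): an output prefix may itself contain a directory path (e.g. `--output results/run1`), and the old code reused the full prefix as the filename stem inside that directory. A minimal sketch of the before/after path construction, using a hypothetical prefix:

```python
import os

prefix = "results/run1"  # hypothetical --output value that contains a path

# Before: the whole prefix was reused as the file stem, producing a path
# inside a subdirectory that was never created
old_path = prefix + "/" + prefix + ".dists"
print(old_path)  # results/run1/results/run1.dists

# After: only the final path component names the file
new_path = prefix + "/" + os.path.basename(prefix) + ".dists"
print(new_path)  # results/run1/run1.dists
```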
45 changes: 25 additions & 20 deletions PopPUNK/__main__.py
@@ -212,7 +212,7 @@ def main():
refList, queryList, distMat = queryDatabase(args.r_files, kmers, args.output, args.output, True,
args.plot_fit, args.no_stream, args.mash, args.threads)

- dists_out = args.output + "/" + args.output + ".dists"
+ dists_out = args.output + "/" + os.path.basename(args.output) + ".dists"
storePickle(refList, queryList, True, distMat, dists_out)
else:
sys.stderr.write("Need to provide a list of reference files with --r-files")
@@ -249,8 +249,8 @@ def main():
model_prefix = args.ref_db
if args.model_dir is not None:
model_prefix = args.model_dir
- old_model = loadClusterFit(model_prefix + "/" + model_prefix + '_fit.pkl',
- model_prefix + "/" + model_prefix + '_fit.npz')
+ old_model = loadClusterFit(model_prefix + "/" + os.path.basename(model_prefix) + '_fit.pkl',
+ model_prefix + "/" + os.path.basename(model_prefix) + '_fit.npz')
if old_model.type == 'refine':
sys.stderr.write("Model needs to be from --fit-model not --refine-model\n")
sys.exit(1)
@@ -273,16 +273,18 @@ def main():
fit_type = 'combined'
model.save()
genomeNetwork = constructNetwork(refList, queryList, assignments, model.within_label)
- isolateClustering = {fit_type: printClusters(genomeNetwork, args.output + "/" + args.output)}
+ isolateClustering = {fit_type: printClusters(genomeNetwork, args.output + "/" + os.path.basename(args.output))}

# Write core and accessory based clusters, if they worked
if model.indiv_fitted:
indivNetworks = {}
for dist_type, slope in zip(['core', 'accessory'], [0, 1]):
indivAssignments = model.assign(distMat, slope)
indivNetworks[dist_type] = constructNetwork(refList, queryList, indivAssignments, model.within_label)
- isolateClustering[dist_type] = printClusters(indivNetworks[dist_type], args.output + "/" + args.output + "_" + dist_type)
- nx.write_gpickle(indivNetworks[dist_type], args.output + "/" + args.output + "_" + dist_type + '_graph.gpickle')
+ isolateClustering[dist_type] = printClusters(indivNetworks[dist_type],
+ args.output + "/" + os.path.basename(args.output) + "_" + dist_type)
+ nx.write_gpickle(indivNetworks[dist_type], args.output + "/" + os.path.basename(args.output) +
+ "_" + dist_type + '_graph.gpickle')
if args.core_only:
fit_type = 'core'
genomeNetwork = indivNetworks['core']
@@ -323,13 +325,14 @@ def main():
newReferencesNames, newReferencesFile = extractReferences(genomeNetwork, args.output)
nodes_to_remove = set(refList).difference(newReferencesNames)
genomeNetwork.remove_nodes_from(nodes_to_remove)
- prune_distance_matrix(refList, nodes_to_remove, distMat, args.output + "/" + args.output + ".dists")
+ prune_distance_matrix(refList, nodes_to_remove, distMat,
+ args.output + "/" + os.path.basename(args.output) + ".dists")
# Read previous database
kmers, sketch_sizes = readMashDBParams(ref_db, kmers, sketch_sizes)
constructDatabase(newReferencesFile, kmers, sketch_sizes, args.output, args.threads,
args.mash, True) # overwrite old db

- nx.write_gpickle(genomeNetwork, args.output + "/" + args.output + '_graph.gpickle')
+ nx.write_gpickle(genomeNetwork, args.output + "/" + os.path.basename(args.output) + '_graph.gpickle')

elif args.assign_query:
if args.ref_db is not None and args.q_files is not None:
@@ -357,8 +360,8 @@ def main():
model_prefix = args.ref_db
if args.model_dir is not None:
model_prefix = args.model_dir
- model = loadClusterFit(model_prefix + "/" + model_prefix + '_fit.pkl',
- model_prefix + "/" + model_prefix + '_fit.npz')
+ model = loadClusterFit(model_prefix + "/" + os.path.basename(model_prefix) + '_fit.pkl',
+ model_prefix + "/" + os.path.basename(model_prefix) + '_fit.npz')
queryAssignments = model.assign(distMat)

# Set directories of previous fit
@@ -370,15 +373,15 @@ def main():
# If a refined fit, may use just core or accessory distances
if args.core_only and model.type == 'refine':
model.slope = 0
- old_network_file = prev_clustering + "/" + prev_clustering + '_core_graph.gpickle'
- old_cluster_file = prev_clustering + "/" + prev_clustering + '_core_clusters.csv'
+ old_network_file = prev_clustering + "/" + os.path.basename(prev_clustering) + '_core_graph.gpickle'
+ old_cluster_file = prev_clustering + "/" + os.path.basename(prev_clustering) + '_core_clusters.csv'
elif args.accessory_only and model.type == 'refine':
model.slope = 1
- old_network_file = prev_clustering + "/" + prev_clustering + '_accessory_graph.gpickle'
- old_cluster_file = prev_clustering + "/" + prev_clustering + '_accessory_clusters.csv'
+ old_network_file = prev_clustering + "/" + os.path.basename(prev_clustering) + '_accessory_graph.gpickle'
+ old_cluster_file = prev_clustering + "/" + os.path.basename(prev_clustering) + '_accessory_clusters.csv'
else:
- old_network_file = prev_clustering + "/" + prev_clustering + '_graph.gpickle'
- old_cluster_file = prev_clustering + "/" + prev_clustering + '_clusters.csv'
+ old_network_file = prev_clustering + "/" + os.path.basename(prev_clustering) + '_graph.gpickle'
+ old_cluster_file = prev_clustering + "/" + os.path.basename(prev_clustering) + '_clusters.csv'
if args.core_only or args.accessory_only:
sys.stderr.write("Can only do --core-only or --accessory-only fits from "
"a refined fit. Using the combined distances.\n")
@@ -387,13 +390,15 @@ def main():
sys.stderr.write("Network loaded: " + str(genomeNetwork.number_of_nodes()) + " samples\n")

# Assign clustering by adding to network
- ordered_queryList, query_distMat = addQueryToNetwork(refList, queryList, args.q_files, genomeNetwork, kmers, queryAssignments, model, args.output, args.no_stream, args.update_db, args.threads, args.mash)
+ ordered_queryList, query_distMat = addQueryToNetwork(refList, queryList, args.q_files,
+ genomeNetwork, kmers, queryAssignments, model, args.output, args.no_stream, args.update_db,
+ args.threads, args.mash)

# if running simple query
print_full_clustering = False
if args.update_db:
print_full_clustering = True
- isolateClustering = {'combined': printClusters(genomeNetwork, args.output + "/" + args.output,
+ isolateClustering = {'combined': printClusters(genomeNetwork, args.output + "/" + os.path.basename(args.output),
old_cluster_file, print_full_clustering)}

# update_db like no full_db
@@ -404,7 +409,7 @@ def main():
newRepresentativesNames, newRepresentativesFile = extractReferences(genomeNetwork, args.output)
if args.full_db is False:
genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(newRepresentativesNames))
- nx.write_gpickle(genomeNetwork, args.output + "/" + args.output + '_graph.gpickle')
+ nx.write_gpickle(genomeNetwork, args.output + "/" + os.path.basename(args.output) + '_graph.gpickle')

# Update the mash database
newQueries = set(newRepresentativesNames).intersection(queryList)
@@ -418,7 +423,7 @@ def main():
combined_seq, core_distMat, acc_distMat = update_distance_matrices(refList, ref_distMat,
ordered_queryList, distMat, query_distMat)
complete_distMat = translate_distMat(combined_seq, core_distMat, acc_distMat)
- dists_out = args.output + "/" + args.output + ".dists"
+ dists_out = args.output + "/" + os.path.basename(args.output) + ".dists"
storePickle(combined_seq, combined_seq, True, complete_distMat, dists_out)


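The same prefix-plus-basename join is repeated on every changed line above. A small helper, purely illustrative and not part of the commit, captures the convention all the patched output paths follow:

```python
import os

def prefixed_file(prefix, suffix):
    # Illustrative only: '<prefix>/<basename(prefix)><suffix>' is the
    # naming convention the patched lines all follow.
    return prefix + "/" + os.path.basename(prefix) + suffix

assert prefixed_file("results/run1", ".dists") == "results/run1/run1.dists"
assert prefixed_file("run1", "_graph.gpickle") == "run1/run1_graph.gpickle"
```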
22 changes: 11 additions & 11 deletions PopPUNK/mash.py
@@ -52,7 +52,7 @@ def getDatabaseName(prefix, k):
db_name (str)
Name of mash db
"""
- return prefix + "/" + prefix + "." + k + ".msh"
+ return prefix + "/" + os.path.basename(prefix) + "." + k + ".msh"


def createDatabaseDir(outPrefix, kmers):
@@ -68,7 +68,7 @@ def createDatabaseDir(outPrefix, kmers):
# check for writing
if os.path.isdir(outputDir):
# remove old database files if not needed
- for msh_file in glob(outputDir + "/" + outPrefix + "*.msh"):
+ for msh_file in glob(outputDir + "/" + os.path.basename(outPrefix) + "*.msh"):
knum = int(msh_file.split('.')[-2])
if not (kmers == knum).any():
sys.stderr.write("Removing old database " + msh_file + "\n")
@@ -104,7 +104,7 @@ def getSketchSize(dbPrefix, klist, mash_exec = 'mash'):

# iterate over kmer lengths
for k in klist:
dbname = "./" + dbPrefix + "/" + dbPrefix + "." + str(k) + ".msh"
dbname = "./" + dbPrefix + "/" + os.path.basename(dbPrefix) + "." + str(k) + ".msh"
try:
mash_cmd = mash_exec + " info -t " + dbname
mash_info = subprocess.Popen(mash_cmd, universal_newlines=True, shell=True, stdout=subprocess.PIPE)
@@ -191,9 +191,9 @@ def joinDBs(db1, db2, klist, mash_exec = 'mash'):
"""
for kmer in klist:
try:
- join_name = db1 + "/" + db1 + "." + str(kmer) + ".joined"
- db1_name = db1 + "/" + db1 + "." + str(kmer) + ".msh"
- db2_name = db2 + "/" + db2 + "." + str(kmer) + ".msh"
+ join_name = db1 + "/" + os.path.basename(db1) + "." + str(kmer) + ".joined"
+ db1_name = db1 + "/" + os.path.basename(db1) + "." + str(kmer) + ".msh"
+ db2_name = db2 + "/" + os.path.basename(db2) + "." + str(kmer) + ".msh"

mash_cmd = mash_exec + " paste " + join_name + " " + db1_name + " " + db2_name
subprocess.run(mash_cmd, shell=True, check=True)
@@ -301,7 +301,7 @@ def runSketch(k, assemblyList, sketch, genome_length, oPrefix, mash_exec = 'mash
(default = 1)
"""
# define database name
dbname = "./" + oPrefix + "/" + oPrefix + "." + str(k)
dbname = "./" + oPrefix + "/" + os.path.basename(oPrefix) + "." + str(k)
dbfilename = dbname + ".msh"

# calculate false positive rate
@@ -395,7 +395,7 @@ def queryDatabase(qFile, klist, dbPrefix, queryPrefix, self = True, number_plot_
with open(qFile, 'r') as queryFile:
for line in queryFile:
queryList.append(line.rstrip())
refList = getSeqsInDb("./" + dbPrefix + "/" + dbPrefix + "." + str(klist[0]) + ".msh", mash_exec)
refList = getSeqsInDb(dbPrefix + "/" + os.path.basename(dbPrefix) + "." + str(klist[0]) + ".msh", mash_exec)

if self:
if dbPrefix != queryPrefix:
@@ -412,8 +412,8 @@
row = 0

# run mash distance query based on current file
ref_dbname = "./" + dbPrefix + "/" + dbPrefix + "." + str(k) + ".msh"
query_dbname = "./" + queryPrefix + "/" + queryPrefix + "." + str(k) + ".msh"
ref_dbname = dbPrefix + "/" + os.path.basename(dbPrefix) + "." + str(k) + ".msh"
query_dbname = queryPrefix + "/" + os.path.basename(queryPrefix) + "." + str(k) + ".msh"
# construct mash command
mash_cmd = mash_exec + " dist -p " + str(threads) + " " + ref_dbname + " " + query_dbname

@@ -613,7 +613,7 @@ def getKmersFromReferenceDatabase(dbPrefix):
"""
# prepare
knum = []
fullDbPrefix = "./" + dbPrefix + "/" + dbPrefix + "."
fullDbPrefix = dbPrefix + "/" + os.path.basename(dbPrefix) + "."

# iterate through files
for msh_file in glob(fullDbPrefix + "*.msh"):
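In mash.py the convention applies to the per-k sketch databases, one file per k-mer length named `<prefix>/<basename(prefix)>.<k>.msh`. A simplified sketch of how the k-mer lengths can be recovered from those filenames, as `getKmersFromReferenceDatabase` does above (assuming this naming, and omitting the real function's error handling):

```python
import os
from glob import glob

def kmer_lengths(dbPrefix):
    # Each k-mer length has its own sketch, <prefix>/<basename>.<k>.msh,
    # so k is the second-to-last dot-separated field of the filename.
    fullDbPrefix = dbPrefix + "/" + os.path.basename(dbPrefix) + "."
    return sorted(int(msh_file.split('.')[-2]) for msh_file in glob(fullDbPrefix + "*.msh"))
```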
20 changes: 10 additions & 10 deletions PopPUNK/models.py
@@ -112,7 +112,7 @@ def fit(self, X = None):
self.subsampled_X /= self.scale

# Show clustering
- plot_scatter(self.subsampled_X, self.outPrefix + "/" + self.outPrefix + "_distanceDistribution",
+ plot_scatter(self.subsampled_X, self.outPrefix + "/" + os.path.basename(self.outPrefix) + "_distanceDistribution",
self.outPrefix + " distances")

def no_scale(self):
@@ -181,14 +181,14 @@ def save(self):
if not self.fitted:
raise RuntimeError("Trying to save unfitted model")
else:
- np.savez(self.outPrefix + "/" + self.outPrefix + '_fit.npz',
+ np.savez(self.outPrefix + "/" + os.path.basename(self.outPrefix) + '_fit.npz',
weights=self.weights,
means=self.means,
covariances=self.covariances,
within=self.within_label,
between=self.between_label,
scale=self.scale)
- with open(self.outPrefix + "/" + self.outPrefix + '_fit.pkl', 'wb') as pickle_file:
+ with open(self.outPrefix + "/" + os.path.basename(self.outPrefix) + '_fit.pkl', 'wb') as pickle_file:
pickle.dump([self.dpgmm, self.type], pickle_file)


@@ -230,7 +230,7 @@ def plot(self, X, y):
"\tNumber of components used\t" + str(used_components)]) + "\n")

title = self.outPrefix + " " + str(len(np.unique(y))) + "-component DPGMM"
- outfile = self.outPrefix + "/" + self.outPrefix + "_DPGMM_fit"
+ outfile = self.outPrefix + "/" + os.path.basename(self.outPrefix) + "_DPGMM_fit"

plot_results(X, y, self.means, self.covariances, self.scale, title, outfile)
plot_contours(y, self.weights, self.means, self.covariances, title + " assignment boundary", outfile + "_contours")
@@ -351,15 +351,15 @@ def save(self):
if not self.fitted:
raise RuntimeError("Trying to save unfitted model")
else:
- np.savez(self.outPrefix + "/" + self.outPrefix + '_fit.npz',
+ np.savez(self.outPrefix + "/" + os.path.basename(self.outPrefix) + '_fit.npz',
n_clusters=self.n_clusters,
within=self.within_label,
between=self.between_label,
means=self.cluster_means,
maxs=self.cluster_maxs,
mins=self.cluster_mins,
scale=self.scale)
- with open(self.outPrefix + "/" + self.outPrefix + '_fit.pkl', 'wb') as pickle_file:
+ with open(self.outPrefix + "/" + os.path.basename(self.outPrefix) + '_fit.pkl', 'wb') as pickle_file:
pickle.dump([self.hdb, self.type], pickle_file)


@@ -402,7 +402,7 @@ def plot(self):
"\tNumber of assignments\t" + str(len(self.labels))]) + "\n")

plot_dbscan_results(self.subsampled_X, self.labels, self.n_clusters,
self.outPrefix + "/" + self.outPrefix + "_dbscan")
self.outPrefix + "/" + os.path.basename(self.outPrefix) + "_dbscan")


def assign(self, X):
@@ -553,11 +553,11 @@ def save(self):
if not self.fitted:
raise RuntimeError("Trying to save unfitted model")
else:
- np.savez(self.outPrefix + "/" + self.outPrefix + '_fit.npz',
+ np.savez(self.outPrefix + "/" + os.path.basename(self.outPrefix) + '_fit.npz',
intercept=np.array([self.optimal_x, self.optimal_y]),
core_acc_intercepts=np.array([self.core_boundary, self.accessory_boundary]),
scale=self.scale)
- with open(self.outPrefix + "/" + self.outPrefix + '_fit.pkl', 'wb') as pickle_file:
+ with open(self.outPrefix + "/" + os.path.basename(self.outPrefix) + '_fit.pkl', 'wb') as pickle_file:
pickle.dump([None, self.type], pickle_file)


@@ -593,7 +593,7 @@ def plot(self, X):
plot_refined_results(X, self.assign(X), self.optimal_x, self.optimal_y, self.core_boundary,
self.accessory_boundary, self.mean0, self.mean1, self.start_point, self.min_move,
self.max_move, self.scale, self.indiv_fitted, "Refined fit boundary",
self.outPrefix + "/" + self.outPrefix + "_refined_fit")
self.outPrefix + "/" + os.path.basename(self.outPrefix) + "_refined_fit")


def assign(self, X, slope=None):
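The models.py changes mirror the loadClusterFit calls patched in __main__.py: each save() writes `<prefix>/<basename(prefix)>_fit.npz` and `_fit.pkl`, and the loader rebuilds exactly the same two paths. A sketch of that round-trip contract, with a hypothetical prefix:

```python
import os

def fit_paths(prefix):
    # Saver and loader must derive identical paths, or a fit written
    # under a path-containing prefix can never be found again.
    base = prefix + "/" + os.path.basename(prefix)
    return base + "_fit.pkl", base + "_fit.npz"

pkl, npz = fit_paths("fits/dbscan_run")  # hypothetical model directory
assert pkl == "fits/dbscan_run/dbscan_run_fit.pkl"
assert npz == "fits/dbscan_run/dbscan_run_fit.npz"
```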
4 changes: 2 additions & 2 deletions PopPUNK/network.py
@@ -20,7 +20,7 @@
from .mash import getDatabaseName
from .mash import getSketchSize

- from .network import iterDistRows
+ from .utils import iterDistRows

def extractReferences(G, outPrefix):
"""Extract references for each cluster based on cliques
@@ -70,7 +70,7 @@ def writeReferences(refList, outPrefix):
The name of the file references were written to
"""
# write references to file
refFileName = "./" + outPrefix + "/" + outPrefix + ".refs"
refFileName = outPrefix + "/" + os.path.basename(outPrefix) + ".refs"
with open(refFileName, 'w') as rFile:
for ref in refList:
rFile.write(ref + '\n')
