Skip to content

Commit

Permalink
Fix visualisation of individually-refined models
Browse files (browse the repository at this point in the history)
  • Loading branch information
nickjcroucher committed Jul 30, 2020
1 parent 7f9879d commit da5b1ac
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
31 changes: 22 additions & 9 deletions PopPUNK/__main__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -461,7 +461,7 @@ def main():
missing_isolates = [refList[m] for m in networkMissing]
sys.stderr.write("WARNING: Samples " + ", ".join(missing_isolates) + " are missing from the final network\n")

fit_type = None
fit_type = model.type
isolateClustering = {fit_type: printClusters(genomeNetwork,
refList,
args.output + "/" + os.path.basename(args.output),
Expand Down Expand Up @@ -576,8 +576,7 @@ def main():
# extract limited references from clique by default
if not args.full_db:
newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = extractReferences(genomeNetwork, refList, args.output)
nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
#
nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
names_to_remove = [refList[n] for n in nodes_to_remove]
prune_distance_matrix(refList, names_to_remove, distMat,
args.output + "/" + os.path.basename(args.output) + ".dists")
Expand Down Expand Up @@ -684,19 +683,33 @@ def main():
model_prefix = args.ref_db
if args.model_dir is not None:
model_prefix = args.model_dir
model = loadClusterFit(model_prefix + "/" + os.path.basename(model_prefix) + '_fit.pkl',
model_prefix + "/" + os.path.basename(model_prefix) + '_fit.npz')
try:
sys.stderr.write('Unable to locate previous model fit in ' + model_prefix + '\n')
model = loadClusterFit(model_prefix + "/" + os.path.basename(model_prefix) + '_fit.pkl',
model_prefix + "/" + os.path.basename(model_prefix) + '_fit.npz')
except:
sys.stderr.write('Unable to locate previous model fit in ' + model_prefix + '\n')
exit()

# Set directories of previous fit
if args.previous_clustering is not None:
prev_clustering = args.previous_clustering
else:
prev_clustering = os.path.dirname(args.distances + ".pkl")

# load clustering
cluster_file = args.ref_db + '/' + args.ref_db + '_clusters.csv'
isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_clusters.csv'
if model.indiv_fitted:
cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_clusters.csv'
isolateClustering['refine'] = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
isolateClustering['refine'] = isolateClustering['refine']['Cluster']
for type in ['accessory','core']:
cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_' + type + '_clusters.csv'
isolateClustering[type] = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
isolateClustering[type] = isolateClustering[type]['Cluster']
else:
cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_clusters.csv'
isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)

# generate selected visualisations
if args.microreact:
sys.stderr.write("Writing microreact output\n")
Expand Down
8 changes: 6 additions & 2 deletions PopPUNK/models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -661,7 +661,8 @@ def save(self):
np.savez(self.outPrefix + "/" + os.path.basename(self.outPrefix) + '_fit.npz',
intercept=np.array([self.optimal_x, self.optimal_y]),
core_acc_intercepts=np.array([self.core_boundary, self.accessory_boundary]),
scale=self.scale)
scale=self.scale,
indiv_fitted=self.indiv_fitted)
with open(self.outPrefix + "/" + os.path.basename(self.outPrefix) + '_fit.pkl', 'wb') as pickle_file:
pickle.dump([None, self.type], pickle_file)

Expand All @@ -681,7 +682,10 @@ def load(self, fit_npz, fit_obj):
self.accessory_boundary = np.asscalar(fit_npz['core_acc_intercepts'][1])
self.scale = fit_npz['scale']
self.fitted = True
self.indiv_fitted = False # Do not output multiple microreacts
if 'indiv_fitted' in fit_npz:
self.indiv_fitted = fit_npz['indiv_fitted']
else:
self.indiv_fitted = False # historical behaviour for backward compatibility
if np.isnan(self.optimal_y) and np.isnan(self.accessory_boundary):
self.threshold = True

Expand Down

0 comments on commit da5b1ac

Please sign in to comment.