Skip to content

Commit

Permalink
Merge pull request #240 from bacpop/threshold-fix
Browse files Browse the repository at this point in the history
Fix for assign with threshold models
  • Loading branch information
johnlees committed Nov 4, 2022
2 parents 75cb2aa + 0a8e04d commit ff01731
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
10 changes: 5 additions & 5 deletions PopPUNK/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,15 +564,15 @@ def assign_query_hdf5(dbFuncs,

else:
# Assign these distances as within or between strain
if fit_type == 'default':
queryAssignments = model.assign(qrDistMat)
dist_type = 'euclidean'
elif fit_type == 'core_refined':
if fit_type == 'core_refined' or model.threshold:
queryAssignments = model.assign(qrDistMat, slope = 0)
dist_type = 'core'
elif fit_type == 'accessory_refined':
queryAssignments = model.assign(qrDistMat, slope = 1)
dist_type = 'accessory'
else:
queryAssignments = model.assign(qrDistMat)
dist_type = 'euclidean'

# QC assignments to check for multi-links
if qc_dict['run_qc'] and qc_dict['max_merge'] > 1:
Expand All @@ -586,7 +586,7 @@ def assign_query_hdf5(dbFuncs,
failed_samples = frozenset(qNames) - seq_names_passing
if len(failed_samples) > 0:
sys.stderr.write(f"{len(failed_samples)} samples failed:\n"
f"{','.join(failed_samples)}\n")
f"{','.join(failed_samples)}\n")
if len(failed_samples) == len(qNames):
sys.exit(1)
else:
Expand Down
8 changes: 5 additions & 3 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,7 +1296,9 @@ def addQueryToNetwork(dbFuncs, rList, qList, G,

# do not calculate weights unless specified
if weights is None:
distance_type = None
weights_type = None
else:
weights_type = distance_type

# These are returned
qqDistMat = None
Expand All @@ -1312,7 +1314,7 @@ def addQueryToNetwork(dbFuncs, rList, qList, G,
previous_network = G,
old_ids = rList,
distMat = weights,
weights_type = distance_type,
weights_type = weights_type,
summarise = False,
use_gpu = use_gpu)

Expand Down Expand Up @@ -1357,7 +1359,7 @@ def addQueryToNetwork(dbFuncs, rList, qList, G,
old_ids = vertex_labels,
adding_qq_dists = True,
distMat = qqDistMat,
weights_type = distance_type,
weights_type = weights_type,
summarise = False,
use_gpu = use_gpu)

Expand Down

0 comments on commit ff01731

Please sign in to comment.