Skip to content

Commit

Permalink
Refactor compare_homolog_groups and fix some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
biologyguy committed May 4, 2018
1 parent 6d0f1a8 commit cc5376a
Showing 1 changed file with 60 additions and 46 deletions.
106 changes: 60 additions & 46 deletions rdmcl/compare_homolog_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,56 +52,60 @@ def __init__(self, true_clusters, query_clusters):
self._prepare_difference()

def _prepare_difference(self):
# For each query cluster, find the true cluster with the most overlap
final_clusters = [[] for _ in range(len(self.query_clusters))]
query_to_true = []
success_tally = 0
true_clusters = list(self.true_clusters)

for q_indx, q_cluster in enumerate(self.query_clusters):
intersections = [list(filter(lambda x: x in q_cluster, sublist)) for sublist in true_clusters]
# Make sure all sequences are present in both sets of clusters
true_seq_ids = set([seq_id for clust in self.true_clusters for seq_id in clust])
query_seq_ids = set([seq_id for clust in self.query_clusters for seq_id in clust])
unique = [seq_id for seq_id in true_seq_ids if seq_id not in query_seq_ids]
unique += [seq_id for seq_id in query_seq_ids if seq_id not in true_seq_ids]
if unique:
raise ValueError("Not all sequences are present in both sets of clusters")

# For each true cluster, find the query cluster with the most overlap
final_clusters = [[] for _ in range(len(self.true_clusters))]
true_to_query = []
query_clusters = list(self.query_clusters)

for t_indx, t_cluster in enumerate(self.true_clusters):
intersections = [list(filter(lambda x: x in t_cluster, sublist)) for sublist in query_clusters]

max_match = 0
max_match_indx = None
for true_indx, intersect in enumerate(intersections):
for query_indx, intersect in enumerate(intersections):
if len(intersect) > max_match:
max_match = len(intersect)
max_match_indx = true_indx
max_match_indx = query_indx

query_to_true.append([q_indx, max_match_indx])
for seq_id in q_cluster:
if max_match_indx is None or seq_id not in self.true_clusters[max_match_indx]:
true_to_query.append([t_indx, max_match_indx])
for seq_id in t_cluster:
if seq_id not in self.query_clusters[max_match_indx]:
if re.search("Mle", seq_id):
final_clusters[q_indx].append("\033[91m\033[4m%s\033[24m\033[39m" % seq_id)
final_clusters[t_indx].append("\033[91m\033[4m%s\033[24m\033[39m" % seq_id)
else:
final_clusters[q_indx].append("\033[91m%s\033[39m" % seq_id)
final_clusters[t_indx].append("\033[91m%s\033[39m" % seq_id)
else:
success_tally += 1
if re.search("Mle", seq_id):
final_clusters[q_indx].append("\033[92m\033[4m%s\033[24m\033[39m" % seq_id)
final_clusters[t_indx].append("\033[92m\033[4m%s\033[24m\033[39m" % seq_id)
else:
final_clusters[q_indx].append("\033[92m%s\033[39m" % seq_id)
if max_match_indx is not None:
true_clusters[max_match_indx] = []
final_clusters[t_indx].append("\033[92m%s\033[39m" % seq_id)

# Setting comparison stats, as explained in Lechner M. et al., 2014 PlosONE. DOI: 10.1371/journal.pone.0105015
# tp = True positive, fp = False positive, fn = False negative, tn = True negative
for q_indx, t_indx in query_to_true:
if t_indx is None:
tp = 0
fp = len(self.query_clusters[q_indx])
fn = 0
tn = self.total_size - fp
else:
tp = len(list(filter(lambda x: x in self.query_clusters[q_indx], self.true_clusters[t_indx])))
fp = len(self.query_clusters[q_indx]) - tp
fn = len([x for x in self.true_clusters[t_indx] if x not in self.query_clusters[q_indx]])
tn = self.total_size - tp - fp - fn

self.precision += (tp / (tp + fp)) * (len(self.query_clusters[q_indx]) / self.total_size)
self.recall += 0 if t_indx is None \
else (tp / (tp + fn)) * (len(self.query_clusters[q_indx]) / self.total_size)
self.accuracy += ((tp + tn) / self.total_size) * (len(self.query_clusters[q_indx]) / self.total_size)
self.tn_rate += (tn / (tn + fp)) * (len(self.query_clusters[q_indx]) / self.total_size)
sum_tp, sum_fp, sum_fn, sum_tn = 0, 0, 0, 0
for t_indx, q_indx in true_to_query:
tp = len(list(filter(lambda x: x in self.true_clusters[t_indx], self.query_clusters[q_indx])))
fp = len(self.query_clusters[q_indx]) - tp
fn = len([x for x in self.true_clusters[t_indx] if x not in self.query_clusters[q_indx]])
tn = self.total_size - tp - fp - fn

sum_tp += tp
sum_fp += fp
sum_fn += fn
sum_tn += tn

self.precision = (sum_tp / (sum_tp + sum_fp))
self.recall = (sum_tp / (sum_tp + sum_fn))
self.accuracy = ((sum_tp + sum_tn) / (sum_tp + sum_tn + sum_fp + sum_fn))
self.tn_rate = (sum_tn / (sum_tn + sum_fp))

query_parent = Cluster([_id for _ids in self.query_clusters for _id in _ids])
self.query_score = sum([Cluster(next_set, parent=query_parent).score() for next_set in self.query_clusters])
Expand All @@ -114,14 +118,17 @@ def _prepare_difference(self):
return

def score(self):
output = "Precision: %s%%\n" % (round(self.precision, 4) * 100)
output += "Recall: %s%%\n" % (round(self.recall, 4) * 100)
output += "Accuracy: %s%%\n" % (round(self.accuracy, 4) * 100)
output += "tn rate: %s%%\n" % (round(self.tn_rate, 4) * 100)
output = "Precision: %s%%\n" % (round(self.precision * 100, 2))
output += "Recall: %s%%\n" % (round(self.recall * 100, 2))
output += "Accuracy: %s%%\n" % (round(self.accuracy * 100, 2))
output += "tn rate: %s%%\n" % (round(self.tn_rate * 100, 2))
output += "Query score: %s\n" % round(self.query_score, 2)
output += "True score: %s\n" % round(self.true_score, 2)
return output

def __str__(self):
return self.score()


class Cluster(rdmcl.Cluster):
def __init__(self, seq_ids, parent=None, taxa_separator="-"):
Expand Down Expand Up @@ -181,12 +188,19 @@ def fmt(prog):

in_args = parser.parse_args()

comparison = Comparison(in_args.true_clusters, in_args.query_clusters)
try:
comparison = Comparison(in_args.true_clusters, in_args.query_clusters)
if in_args.score:
print(comparison.score())
else:
print(comparison.pretty_out)

except ValueError as err:
if "Not all sequences are present" not in str(err):
raise

if in_args.score:
print(comparison.score())
else:
print(comparison.pretty_out)
br._stderr(hlp.RED + "Error!" + hlp.END + "\n")
br._stderr("There are differences in the sequences present in your query and true clusters files.\n")


if __name__ == '__main__':
Expand Down

0 comments on commit cc5376a

Please sign in to comment.