Skip to content

Commit

Permalink
Merge pull request #519 from GavinHuttley/develop
Browse files Browse the repository at this point in the history
bug fixes related to alignment quality and motif counting
  • Loading branch information
GavinHuttley committed Feb 5, 2020
2 parents 928888b + 579c3e1 commit 449cdaf
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 24 deletions.
35 changes: 26 additions & 9 deletions src/cogent3/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@ def counts_per_seq(
include_ambiguity
if True, motifs containing ambiguous characters
from the seq moltype are included. No expansion of those is attempted.
allow_gaps
allow_gap
if True, motifs containing a gap character are included.
Notes
Expand Down Expand Up @@ -2255,12 +2255,17 @@ def alignment_quality(self, equifreq_mprobs=True):
counts = self.counts_per_pos()
if counts.array.max() == 0 or len(self.seqs) == 1:
return None

motif_probs = self.get_motif_probs()

if equifreq_mprobs:
num_motifs = len(counts.motifs)
p = array([1 / num_motifs] * num_motifs)
else:
motif_probs = self.get_motif_probs()
p = array([motif_probs[b] for b in counts.motifs])
# we reduce motif_probs to observed states
motif_probs = {m: v for m, v in motif_probs.items() if v > 0}
num_motifs = len(motif_probs)
motif_probs = {m: 1 / num_motifs for m in motif_probs}

p = array([motif_probs.get(b, 0.0) for b in counts.motifs])

cols = p != 0
p = p[cols]
counts = counts.array[:, cols]
Expand Down Expand Up @@ -2988,17 +2993,29 @@ def counts_per_pos(

data = list(self.to_dict().values())
alpha = self.moltype.alphabet.get_word_alphabet(motif_length)
all_motifs = set() if allow_gap or include_ambiguity else None
all_motifs = set()
exclude_chars = set()
if not allow_gap:
exclude_chars.update(self.moltype.gap)

if not include_ambiguity:
ambigs = [c for c, v in self.moltype.ambiguities.items() if len(v) > 1]
exclude_chars.update(ambigs)

result = []
for i in range(0, len(self) - motif_length + 1, motif_length):
counts = CategoryCounter([s[i : i + motif_length] for s in data])
if all_motifs is not None:
all_motifs.update(list(counts))
all_motifs.update(list(counts))
result.append(counts)

if all_motifs:
alpha += tuple(sorted(set(alpha) ^ all_motifs))

if exclude_chars:
# this additional clause is required for the bytes moltype
# That moltype includes '-' as a character
alpha = [m for m in alpha if not (set(m) & exclude_chars)]

for i, counts in enumerate(result):
result[i] = counts.tolist(alpha)

Expand Down
88 changes: 73 additions & 15 deletions tests/test_core/test_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
Sequence,
frac_same,
)
from cogent3.maths.util import safe_p_log_p
from cogent3.parse.fasta import MinimalFastaParser
from cogent3.util.misc import get_object_provenance
from cogent3.util.unit_test import TestCase, main
Expand Down Expand Up @@ -1538,12 +1539,16 @@ def test_alignment_quality(self):
"""Tests that the alignment_quality generates the right alignment quality
value based on the Hertz-Stormo metric. expected values are hand calculated
using the formula in the paper."""
aln = make_aligned_seqs(["AATTGA", "AGGTCC", "AGGATG", "AGGCGT"], moltype="dna")
aln = self.Class(["AATTGA", "AGGTCC", "AGGATG", "AGGCGT"], moltype="dna")
got = aln.alignment_quality(equifreq_mprobs=True)
expect = log2(4) + (3 / 2) * log2(3) + (1 / 2) * log2(2) + (1 / 2) * log2(2)
assert_allclose(got, expect)
# should be the same with the default moltype too
aln = self.Class(["AATTGA", "AGGTCC", "AGGATG", "AGGCGT"])
got = aln.alignment_quality(equifreq_mprobs=True)
assert_allclose(got, expect)

aln = make_aligned_seqs(["AAAC", "ACGC", "AGCC", "A-TC"], moltype="dna")
aln = self.Class(["AAAC", "ACGC", "AGCC", "A-TC"], moltype="dna")
got = aln.alignment_quality(equifreq_mprobs=False)
expect = (
2 * log2(1 / 0.4)
Expand All @@ -1553,24 +1558,23 @@ def test_alignment_quality(self):
)
assert_allclose(got, expect)

# 1. Alignment just gaps (Gap chars need to be fixed for unspecified moltype, before uncommenting).
# aln = make_aligned_seqs(["----"])
# got = aln.alignment_quality(equifreq_mprobs=True)
# assert_allclose(got, 0)
# 1. Alignment just gaps - alignment_quality returns None
aln = self.Class(["----", "----"])
got = aln.alignment_quality(equifreq_mprobs=True)
self.assertIsNone(got)

# 2 Just one sequence (I've made an assumption that if there is one sequence,
# the alignment quality should also return None, correct me if I'm wrong).
aln = make_aligned_seqs(["AAAC"])
# 2 Just one sequence - alignment_quality returns None
aln = self.Class(["AAAC"])
got = aln.alignment_quality(equifreq_mprobs=True)
assert got is None
self.assertIsNone(got)

# 3.1 Two seqs, one all gaps. (equifreq_mprobs=True)
aln = make_aligned_seqs(["----", "ACAT"])
aln = self.Class(["----", "ACAT"])
got = aln.alignment_quality(equifreq_mprobs=True)
assert_allclose(got, 28)
assert_allclose(got, 1.1699250014423124)

# 3.2 Two seqs, one all gaps. (equifreq_mprobs=False)
aln = make_aligned_seqs(["----", "AAAA"])
aln = self.Class(["----", "AAAA"])
got = aln.alignment_quality(equifreq_mprobs=False)
assert_allclose(got, -2)

Expand Down Expand Up @@ -2336,6 +2340,9 @@ def test_counts_per_pos(self):
aln = self.Class([s1, s2, s4], moltype=DNA)
obs = aln.counts_per_pos(allow_gap=True)
self.assertEqual(obs.array, exp_gap)
aln = self.Class(["-RAT", "ACCT", "GTGT"], moltype="dna")
c = aln.counts_per_pos(include_ambiguity=False, allow_gap=True)
self.assertEqual(set(c.motifs), set("ACGT-"))

def test_counts_per_seq_default_moltype(self):
"""produce correct counts per seq with default moltypes"""
Expand All @@ -2349,6 +2356,29 @@ def test_counts_per_seq_default_moltype(self):
got = coll.counts_per_seq(include_ambiguity=True, allow_gap=True)
self.assertEqual(got.col_sum()["-"], 2)

def test_counts_per_pos_default_moltype(self):
"""produce correct counts per pos with default moltypes"""
data = {"a": "AAAA??????", "b": "CCCGGG--NN", "c": "CCGGTTCCAA"}
coll = self.Class(data=data)
got = coll.counts_per_pos()
# should not include gap character
self.assertNotIn("-", got.motifs)
# allowing gaps
got = coll.counts_per_pos(allow_gap=True)
# should include gap character
self.assertEqual(got[5, "-"], 0)
self.assertEqual(got[6, "-"], 1)

# now with motif-length 2
got = coll.counts_per_pos(motif_length=2)
found_motifs = set()
lengths = set()
for m in got.motifs:
lengths.add(len(m))
found_motifs.update(m)
self.assertTrue("-" not in found_motifs)
self.assertEqual(lengths, {2})

def test_get_seq_entropy(self):
"""ArrayAlignment get_seq_entropy should get entropy of each seq"""
seqs = [AB.make_seq(s, preserve_case=True) for s in ["abab", "bbbb", "abbb"]]
Expand Down Expand Up @@ -3149,13 +3179,41 @@ def test_entropy_per_pos(self):
f = a.entropy_per_pos(allow_gap=True)
e = array([1.584962500721156, 1.584962500721156, 1.584962500721156, 0])
self.assertEqual(f, e)

seqs = []
for s in ["-RAT", "ACCT", "GTGT"]:
seqs.append(make_seq(s, moltype="dna"))
a = ArrayAlignment(seqs)

# "-RAT"
# "ACCT"
# "GTGT"
f = a.entropy_per_pos(allow_gap=False, include_ambiguity=False)
e = [
2 * safe_p_log_p(array([1 / 2])).sum(),
2 * safe_p_log_p(array([1 / 2])).sum(),
3 * safe_p_log_p(array([1 / 3])).sum(),
0,
]
assert_allclose(f, e)

f = a.entropy_per_pos(include_ambiguity=True)
e = array([1.584962500721156, 1.584962500721156, 1.584962500721156, 0])
self.assertEqual(f, e)
e = [
2 * safe_p_log_p(array([1 / 2])).sum(),
3 * safe_p_log_p(array([1 / 3])).sum(),
3 * safe_p_log_p(array([1 / 3])).sum(),
0,
]
assert_allclose(f, e)

f = a.entropy_per_pos(allow_gap=True)
e = [
3 * safe_p_log_p(array([1 / 3])).sum(),
2 * safe_p_log_p(array([1 / 2])).sum(),
3 * safe_p_log_p(array([1 / 3])).sum(),
0,
]
assert_allclose(f, e)

def test_coevolution_segments(self):
"""specifying coordinate segments produces matrix with just those"""
Expand Down

0 comments on commit 449cdaf

Please sign in to comment.