Skip to content

Commit

Permalink
Merge pull request #469 from GavinHuttley/develop
Browse files Browse the repository at this point in the history
Allow setting a ref_name preference for alignments
  • Loading branch information
GavinHuttley committed Jan 7, 2020
2 parents 4facf22 + a065a72 commit 8d37d75
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 13 deletions.
26 changes: 18 additions & 8 deletions src/cogent3/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def __init__(
# both SequenceCollections and Alignments.
self._set_additional_attributes(curr_seqs)

self._repr_policy = dict(num_seqs=10, num_pos=60)
self._repr_policy = dict(num_seqs=10, num_pos=60, ref_name="longest")

def __str__(self):
"""Returns self in FASTA-format, respecting name order."""
Expand Down Expand Up @@ -1869,25 +1869,34 @@ def apply_pssm(

return array(result)

def set_repr_policy(self, num_seqs=None, num_pos=None):
def set_repr_policy(self, num_seqs=None, num_pos=None, ref_name=None):
"""specify policy for repr(self)
Parameters
----------
num_seqs
num_seqs : int or None
number of sequences to include in represented display.
num_pos
num_pos : int or None
length of sequences to include in represented display.
ref_name : str or None
name of sequence to be placed first, or "longest" (default).
If latter, indicates longest sequence will be chosen.
"""
if not any([num_seqs, num_pos]):
return
if num_seqs:
assert isinstance(num_seqs, int), "num_seqs is not an integer"
self._repr_policy["num_seqs"] = num_seqs

if num_pos:
assert isinstance(num_pos, int), "num_pos is not an integer"
self._repr_policy["num_pos"] = num_pos

if ref_name:
assert isinstance(ref_name, str), "ref_name is not a string"
if ref_name != "longest" and ref_name not in self.names:
raise ValueError(f"no sequence name matching {ref_name}")

self._repr_policy["ref_name"] = ref_name

def probs_per_seq(
self,
motif_length=1,
Expand Down Expand Up @@ -2680,7 +2689,6 @@ def _get_raw_pretty(self, name_order):
if name_order is not None:
assert set(name_order) <= set(self.names), "names don't match"

names = name_order or self.names
output = defaultdict(list)
names = name_order or self.names
num_seqs = len(names)
Expand Down Expand Up @@ -3681,13 +3689,15 @@ def __getitem__(self, item):
data = vstack(data)
else:
data = self.array_seqs[:, item]
return self.__class__(
result = self.__class__(
data.T,
list(map(str, self.names)),
self.alphabet,
conversion_f=aln_from_array,
info=self.info,
)
result._repr_policy.update(self._repr_policy)
return result

def _coerce_seqs(self, seqs, is_array):
"""Controls how seqs are coerced in _names_seqs_order.
Expand Down
2 changes: 2 additions & 0 deletions src/cogent3/core/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def __getitem__(self, index):
new = self._mapped(map)
sliced_annots = self._sliced_annotations(new, map)
new.attach_annotations(sliced_annots)
if hasattr(self, "_repr_policy"):
new._repr_policy.update(self._repr_policy)
return new

def _mapped(self, map):
Expand Down
53 changes: 48 additions & 5 deletions tests/test_core/test_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,7 +1370,9 @@ def test_set_repr_policy_no_input(self):
"""repr_policy should remain unchanged"""
seqs = self.Class({"a": "AAAAA"})
seqs.set_repr_policy(num_seqs=None, num_pos=None)
self.assertEqual(seqs._repr_policy, dict(num_seqs=10, num_pos=60))
self.assertEqual(
seqs._repr_policy, dict(num_seqs=10, num_pos=60, ref_name="longest")
)

def test_set_repr_policy_invalid_input(self):
"""repr_policy should remain unchanged"""
Expand All @@ -1379,13 +1381,22 @@ def test_set_repr_policy_invalid_input(self):
seqs.set_repr_policy(num_seqs="foo", num_pos=4.2)
self.fail("Inputs not detected as invalid")
except AssertionError:
self.assertEqual(seqs._repr_policy, dict(num_seqs=10, num_pos=60))
self.assertEqual(
seqs._repr_policy, dict(num_seqs=10, num_pos=60, ref_name="longest")
)

def test_set_repr_policy_valid_input(self):
"""repr_policy should be set to new values"""
seqs = self.Class({"a": "AAAAA"})
seqs.set_repr_policy(num_seqs=5, num_pos=40)
self.assertEqual(seqs._repr_policy, dict(num_seqs=5, num_pos=40))
seqs = self.Class({"a": "AAAAA", "b": "AAA--"})
seqs.set_repr_policy(num_seqs=5, num_pos=40, ref_name="a")
self.assertEqual(seqs._repr_policy, dict(num_seqs=5, num_pos=40, ref_name="a"))
# should persist in slicing
if self.Class == SequenceCollection:
return True

self.assertEqual(
seqs[:2]._repr_policy, dict(num_seqs=5, num_pos=40, ref_name="a")
)

def test_get_seq_entropy(self):
"""get_seq_entropy should get entropy of each seq"""
Expand Down Expand Up @@ -2015,6 +2026,23 @@ def test_to_html(self):
self.assertTrue(other_row in got)
self.assertTrue(got.find(ref_row) < got.find(other_row))

# using different ref sequence
ref_row = (
'<tr><td class="label">seq2</td>'
'<td><span class="terminal_ambig_dna">-</span>'
'<span class="C_dna">C</span>'
'<span class="T_dna">T</span></td></tr>'
)
other_row = (
'<tr><td class="label">seq1</td>'
'<td><span class="A_dna">A</span>'
'<span class="C_dna">.</span>'
'<span class="G_dna">G</span></td></tr>'
)
got = aln.to_html(ref_name="seq2")
# order now changes
self.assertTrue(got.find(ref_row) < got.find(other_row))

def test_variable_positions(self):
"""correctly identify variable positions"""
new_seqs = {"seq1": "ACGTACGT", "seq2": "ACCGACGT", "seq3": "ACGTACGT"}
Expand Down Expand Up @@ -2380,6 +2408,21 @@ def test_seq_entropy_just_gaps(self):
entropy = a.entropy_per_seq()
self.assertIs(entropy, None)

def test_repr_html(self):
"""exercises method normally invoked in notebooks"""
aln = self.Class({"a": "AAAAA", "b": "AAA--"})
aln.set_repr_policy(num_seqs=5, num_pos=40)
self.assertEqual(aln[:3]._repr_policy, aln._repr_policy)
row_a = '<tr><td class="label">a</td>'
row_b = '<tr><td class="label">b</td>'
# default order is longest sequence at top
got = aln._repr_html_()
self.assertTrue(got.find(row_a) < got.find(row_b))
# change order, a should now be last
aln.set_repr_policy(num_seqs=5, num_pos=40, ref_name="b")
got = aln._repr_html_()
self.assertTrue(got.find(row_a) > got.find(row_b))


class ArrayAlignmentTests(AlignmentBaseTests, TestCase):
Class = ArrayAlignment
Expand Down

0 comments on commit 8d37d75

Please sign in to comment.