Skip to content

Commit

Permalink
Merge pull request #480 from jamesmartini/develop_deepcopy476
Browse files Browse the repository at this point in the history
Added deepcopy method to _SequenceCollectionBase
  • Loading branch information
GavinHuttley committed Jan 15, 2020
2 parents fbadc0f + 8b734cb commit b5a8c2c
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
36 changes: 35 additions & 1 deletion src/cogent3/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,10 @@ def __init__(
# if we're forcing the same data, skip the validation
if force_same_data:
self._force_same_data(data, names)
curr_seqs = data
if isinstance(data, dict):
curr_seqs = list(data.values())
else:
curr_seqs = data
# otherwise, figure out what we got and coerce it into the right type
else:
per_seq_names, curr_seqs, name_order = self._names_seqs_order(
Expand Down Expand Up @@ -546,6 +549,7 @@ def _set_additional_attributes(self, curr_seqs):

def _force_same_data(self, data, names):
"""Forces dict that was passed in to be used as self.named_seqs"""
assert isinstance(data, dict), "force_same_data requires input data is a dict"
self.named_seqs = data
self.names = names or list(data.keys())

Expand All @@ -554,6 +558,21 @@ def copy(self):
result = self.__class__(self, moltype=self.moltype, info=self.info)
return result

def deepcopy(self, sliced=True):
"""Returns deep copy of self."""
new_seqs = dict()
for seq in self.seqs:
try:
new_seq = seq.deepcopy(sliced=sliced)
except AttributeError:
new_seq = seq.copy()
new_seqs[seq.name] = new_seq

info = deepcopy(self.info)
result = self.__class__(new_seqs, moltype=self.moltype, info=info, force_same_data=True)
result._repr_policy.update(self._repr_policy)
return result

def _get_alphabet_and_moltype(self, alphabet, moltype, data):
"""Returns alphabet and moltype, giving moltype precedence."""
if type(moltype) == str:
Expand Down Expand Up @@ -3654,6 +3673,7 @@ def _force_same_data(self, data, names):
"""Forces array that was passed in to be used as selfarray_positions"""
if isinstance(data, ArrayAlignment):
data = data._positions

self.array_positions = data
self.names = names or self._make_names(len(data[0]))

Expand Down Expand Up @@ -4198,6 +4218,20 @@ def get_identical_sets(self, mask_degen=False):

return identical_sets

def deepcopy(self, sliced=True):
"""Returns deep copy of self."""
info = deepcopy(self.info)
positions = deepcopy(self.array_seqs)
result = self.__class__(
positions,
force_same_data=True,
moltype=self.moltype,
info=info,
names=self.names,
)
result._repr_policy.update(self._repr_policy)
return result


class CodonArrayAlignment(ArrayAlignment):
"""Stores alignment of gapped codons, no degenerate symbols."""
Expand Down
24 changes: 24 additions & 0 deletions tests/test_core/test_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,16 @@ def setUp(self):
self.b = Alignment(["AAA", "AAA"])
self.c = SequenceCollection(["AAA", "AAA"])

def test_deepcopy(self):
"""correctly deep copy aligned objects in an alignment"""
data = {"seq1": "ACGACGACG", "seq2": "ACGACGACG"}
seqs = self.Class(data)
copied = seqs.deepcopy(sliced=True)
self.assertEqual(seqs.to_rich_dict(), copied.to_rich_dict())
self.assertNotEqual(id(copied), id(seqs))
for name in seqs.names:
self.assertNotEqual(id(copied.named_seqs[name]), copied.named_seqs[name])

def test_guess_input_type(self):
"""SequenceCollection _guess_input_type should figure out data type correctly"""
git = self.a._guess_input_type
Expand Down Expand Up @@ -2469,6 +2479,20 @@ def test_slice_align_info(self):
class AlignmentTests(AlignmentBaseTests, TestCase):
Class = Alignment

def test_sliced_deepcopy(self):
"""correctly deep copy aligned objects in an alignment"""
data = {"seq1": "ACGACGACG", "seq2": "ACGACGACG"}
orig = self.Class(data)
aln = orig[2:5]

notsliced = aln.deepcopy(sliced=False)
sliced = aln.deepcopy(sliced=True)
for name in orig.names:
self.assertEqual(len(notsliced.named_seqs[name].data), len(orig.named_seqs[name].data))
self.assertLessThan(len(sliced.named_seqs[name].data), len(orig.named_seqs[name].data))
self.assertEqual(notsliced.named_seqs[name].map.parent_length, len(orig))
self.assertEqual(sliced.named_seqs[name].map.parent_length, 3)

def test_sliding_windows(self):
"""sliding_windows should return slices of alignments."""
alignment = self.Class(
Expand Down

0 comments on commit b5a8c2c

Please sign in to comment.