Skip to content

Commit

Permalink
only care about overlaps in the maps
Browse files Browse the repository at this point in the history
  • Loading branch information
sroet committed Oct 16, 2020
1 parent 41c07c3 commit 46aa736
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 120 deletions.
12 changes: 7 additions & 5 deletions contact_map/contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def __init__(self, topology, query, haystack, cutoff, n_neighbors_ignored):
# Make tuple for efficient lookup
all_atoms_set = set(query).union(set(haystack))
self._all_atoms = tuple(sorted(list(all_atoms_set)))

self._all_residues = _residue_idx_for_atom(self._topology,
all_atoms_set)
self._use_atom_slice = self._set_atom_slice(self._all_atoms)
has_indexer = getattr(self, 'indexer', None) is not None
if not has_indexer:
Expand Down Expand Up @@ -215,6 +216,8 @@ def to_dict(self):
'haystack': list([int(val) for val in self._haystack]),
'all_atoms': tuple(
[int(val) for val in self._all_atoms]),
'all_residues': tuple(
[int(val) for val in self._all_residues]),
'n_neighbors_ignored': self._n_neighbors_ignored,
'atom_contacts':
self._serialize_contact_counter(self._atom_contacts),
Expand Down Expand Up @@ -245,6 +248,7 @@ def from_dict(cls, dct):
'query': deserialize_set,
'haystack': deserialize_set,
'all_atoms': deserialize_set,
'all_residues': deserialize_set,
'atom_idx_to_residue_idx': deserialize_atom_to_residue_dct
}
for key in deserialization_helpers:
Expand Down Expand Up @@ -820,10 +824,8 @@ def __init__(self, positive, negative):
negative)
self._all_atoms_intersect = set(
positive._all_atoms).intersection(negative._all_atoms)
self._all_residues_intersect = _residue_idx_for_atom(
topology,
self._all_atoms_intersect
)
self._all_residues_intersect = set(
positive._all_residues).intersection(negative._all_residues)
super(ContactDifference, self).__init__(topology,
query,
haystack,
Expand Down
37 changes: 23 additions & 14 deletions contact_map/tests/test_contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,25 @@ def test_non_important_residues(self):
assert diff.atom_contacts is not None
assert diff.residue_contacts is not None

def test_truncated_diffs_residues(self):
    """Diffs are only made for atoms and residues present in both maps.

    A diff between a map of the full trajectory and a map of an
    atom-sliced frame must build without topology errors and must have
    the same contact counters as a diff between two atom-sliced maps.
    """
    ttraj = ContactFrequency(traj[0:4], cutoff=0.075,
                             n_neighbors_ignored=0)
    # Atom-sliced copy of the trajectory, built from atom indices 0-5
    traj_trunc = traj.atom_slice(range(6))
    ttraj_trunc = ContactFrequency(traj_trunc[0:4], cutoff=0.075,
                                   n_neighbors_ignored=0)
    frame_trunc = ContactFrequency(traj_trunc[4], cutoff=0.075,
                                   n_neighbors_ignored=0)
    # The diff should only be made for atoms and residues in both maps
    diff1 = ttraj - frame_trunc
    diff2 = ttraj_trunc - frame_trunc
    # Make sure the topology does not fail here
    assert diff1.atom_contacts is not None
    assert diff1.residue_contacts is not None

    # Make sure the diffs are equal
    assert diff1.atom_contacts.counter == diff2.atom_contacts.counter
    assert diff1.residue_contacts.counter == diff2.residue_contacts.counter

def test_residue_rename_gives_different_atoms(self):
ttraj = ContactFrequency(traj[0:4], cutoff=0.075,
n_neighbors_ignored=0)
Expand Down Expand Up @@ -881,16 +900,7 @@ def test_broken_residues(self):
chain = frame.topology.chain(0)
_ = frame.topology.add_residue(name='test',
chain=chain)
res = frame.topology.residue(0)

frame.topology.atom(9).residue = res
frame.topology.atom(8).residue = res

assert frame.topology.atom(9).residue.index == 0
assert frame.topology.atom(8).residue.index == 0
assert ttraj.topology.atom(9).residue.index == 4
assert ttraj.topology.atom(8).residue.index == 4

frame.topology.residue(0).resSeq = "test"

with pytest.raises(RuntimeError) as e:
diff = ttraj - frame
Expand All @@ -902,6 +912,7 @@ def test_broken_residues_missing(self):
frame = ContactFrequency(traj[4], cutoff=0.075,
n_neighbors_ignored=0)

# Grab atom info to add 'equal' atoms at the end
element8 = frame.topology.atom(8).element
element9 = frame.topology.atom(9).element

Expand All @@ -912,7 +923,7 @@ def test_broken_residues_missing(self):
serial9 = frame.topology.atom(9).serial


# subset the topology
# Remove residue from the internals of the topology
frame._topology = frame.topology.subset(range(8))
assert frame.topology.n_residues == 4
# Add the original atoms again
Expand Down Expand Up @@ -1101,9 +1112,7 @@ def test_still_broken_atoms_and_residues_missing(self):
frame = ContactFrequency(traj[4], cutoff=0.075,
n_neighbors_ignored=0)

# Now we are going to delete an atom; this needs to be the last atom
# otherwise we trigger on a short-circuit
frame._topology = frame.topology.subset(range(2))
frame.topology.delete_atom_by_index(9)

# Make sure this still breaks
with pytest.raises(RuntimeError) as e:
Expand Down
74 changes: 44 additions & 30 deletions contact_map/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,26 @@ def atoms_eq(one, other):
return all(checks)


def check_residues_ok(top0, top1, atoms, out_topology=None):
def residue_eq(one, other):
    """Check if two residues are equal, ignoring their name and chain.

    Both arguments must expose ``index``, ``resSeq`` and ``segment_id``;
    the result is True only if all three compare equal.
    """
    checks = [one.index == other.index,
              one.resSeq == other.resSeq,
              one.segment_id == other.segment_id]
    return all(checks)


def check_residues_ok(top0, top1, residues, out_topology=None):
"""Check if the residues in two topologies are equal.
If an out_topology is given, and the residues only differ in name, that
residue name will be updated inplace in the out_topology
"""
res_idx = _get_residue_indices(top0, top1, atoms)
res_idx = _get_residue_indices(top0, top1, residues)

all_res_ok = (bool(len(res_idx))) # True if is bigger than 0

all_res_ok &= not _count_mismatching_names(top0, top1, res_idx,
out_topology)
all_res_ok &= not _count_mismatching_residues(top0, top1, res_idx,
out_topology)
return all_res_ok


Expand All @@ -43,43 +51,44 @@ def _get_all_possible_residue_indices(top, atoms):
try:
out.append(top.atom(i).residue.index)
except IndexError:
break
pass
return set(out)


def _get_residue_indices(top0, top1, atoms):
"""Get the residue indices or an empty list if not equal."""
out_idx = {}
res_idx0 = _get_all_possible_residue_indices(top0, atoms)
res_idx1 = _get_all_possible_residue_indices(top1, atoms)

# Check if the involved indices are equal
if res_idx0 == res_idx1:
out_idx = res_idx0
return out_idx
def _get_residue_indices(top0, top1, residues):
    """Get the residue indices, or an empty list if any is missing.

    Every index in ``residues`` is looked up with ``top.residue(i)`` in
    both topologies. If a lookup raises ``IndexError`` in either one, an
    empty list is returned; otherwise the list of residue indices is.
    """
    # Materialize once so a one-shot iterable is not exhausted by the
    # lookup in the first topology.
    residues = list(residues)
    res = []
    for top in (top0, top1):
        try:
            res = [top.residue(i).index for i in residues]
        except IndexError:
            # Return the same container type as the success path
            return []
    return res


def _count_mismatching_names(top0, top1, residx, out_topology=None):
def _count_mismatching_residues(top0, top1, residx, out_topology=None):
"""Check for mismatching names.
This will return truthy value if found and not fixable.
It also assumes all indices are present in both topologies
"""
# Check if the names are different
mismatched_idx = []
mismatched_other = []
for idx in residx:
name0 = top0.residue(idx).name
name1 = top1.residue(idx).name
if name0 != name1:
mismatched_idx.append((idx, name0, name1))
res0 = top0.residue(idx)
res1 = top1.residue(idx)
if not residue_eq(res0, res1):
mismatched_other.append(idx)
elif res0.name != res1.name:
mismatched_idx.append((idx, res0.name, res1.name))

if out_topology:
_fix_topology(mismatched_idx, out_topology)
_fix_residue_names(mismatched_idx, out_topology)
mismatched_idx = []
return len(mismatched_idx)
return len(mismatched_idx)+len(mismatched_other)


def _fix_topology(mismatched_idx, out_topology):
def _fix_residue_names(mismatched_idx, out_topology):
"""Fix the topology, assumes all indices are present"""
for idx, name0, name1 in mismatched_idx:
out_topology.residue(idx).name = "/".join([name0, name1])
Expand All @@ -106,20 +115,25 @@ def check_topologies(map0, map1, override_topology):
top0 = map0.topology
top1 = map1.topology

# Figure out the overlapping atoms
all_atoms0 = set(map0._all_atoms)
all_atoms1 = set(map1._all_atoms)

# Figure out overlapping residues
all_residues0 = set(map0._all_residues)
all_residues1 = set(map1._all_residues)

# This is intersect (for difference)
overlap_atoms = all_atoms0.intersection(all_atoms1)
overlap_residues = all_residues0.intersection(all_residues1)
top0, top1, topology = _get_default_topologies(top0, top1,
override_topology)

if override_topology:
override_topology = topology

# Figure out the overlapping atoms
all_atoms0 = set(map0._all_atoms)
all_atoms1 = set(map1._all_atoms)

# This is intersect (for difference)
overlap_atoms = all_atoms0 & all_atoms1
all_atoms_ok = check_atoms_ok(top0, top1, overlap_atoms)
all_res_ok = check_residues_ok(top0, top1, overlap_atoms,
all_res_ok = check_residues_ok(top0, top1, overlap_residues,
override_topology)
if not all_res_ok and not all_atoms_ok:
topology = md.Topology()
Expand Down

0 comments on commit 46aa736

Please sign in to comment.