Skip to content

Commit

Permalink
Merge branch 'master' of github.com:dwhswenson/contact_map into api_c…
Browse files Browse the repository at this point in the history
…leanup
  • Loading branch information
dwhswenson committed Oct 21, 2020
2 parents a05fe7a + 350a28e commit 929fda7
Show file tree
Hide file tree
Showing 9 changed files with 13,804 additions and 35 deletions.
4 changes: 3 additions & 1 deletion contact_map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
__version__ = version.version

from .contact_map import (
ContactMap, ContactFrequency, ContactDifference
ContactMap, ContactFrequency, ContactDifference,
AtomMismatchedContactDifference, ResidueMismatchedContactDifference,
OverrideTopologyContactDifference
)

from .contact_count import ContactCount
Expand Down
13 changes: 13 additions & 0 deletions contact_map/contact_count.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import scipy
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -310,3 +311,15 @@ def most_common_idx(self):
most_common : same thing, using objects as key
"""
return self._counter.most_common()

def filter(self, idx):
"""New ContactCount filtered to idx.
Returns a new ContactCount with the only the counter keys/values
where both the keys are in idx
"""
dct = {k: v for k, v in self._counter.items()
if all([i in idx for i in k])}
new_count = collections.Counter()
new_count.update(dct)
return ContactCount(new_count, self._object_f, self.n_x, self.n_y)
172 changes: 139 additions & 33 deletions contact_map/contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .contact_count import ContactCount
from .atom_indexer import AtomSlicedIndexer, IdentityIndexer
from .py_2_3 import inspect_method_arguments
from .fix_parameters import ParameterFixer

# TODO:
# * switch to something where you can define the haystack -- the trick is to
Expand All @@ -25,6 +26,7 @@
# query atom. Doesn't look like anything is doing that now: neighbors
# doesn't use voxels, neighborlist doesn't limit the haystack


def _residue_and_index(residue, topology):
res = residue
try:
Expand Down Expand Up @@ -57,11 +59,14 @@ def residue_neighborhood(residue, n=1):
return [idx for idx in neighborhood if idx in chain]



def _residue_for_atom(topology, atom_list):
return set([topology.atom(a).residue for a in atom_list])


def _residue_idx_for_atom(topology, atom_list):
return set([topology.atom(a).residue.index for a in atom_list])


def _range_from_object_list(object_list):
"""
Objects must have .index attribute (e.g., MDTraj Residue/Atom)
Expand Down Expand Up @@ -129,11 +134,11 @@ def __init__(self, topology, query, haystack, cutoff, n_neighbors_ignored):
self._query = set(query)
self._haystack = set(haystack)


# Make tuple for efficient lookupt
all_atoms_set = set(query).union(set(haystack))
self._all_atoms = tuple(sorted(list(all_atoms_set)))

self._all_residues = _residue_idx_for_atom(self._topology,
all_atoms_set)
self._use_atom_slice = self._set_atom_slice(self._all_atoms)
has_indexer = getattr(self, 'indexer', None) is not None
if not has_indexer:
Expand Down Expand Up @@ -211,10 +216,12 @@ def to_dict(self):
'haystack': list([int(val) for val in self._haystack]),
'all_atoms': tuple(
[int(val) for val in self._all_atoms]),
'all_residues': tuple(
[int(val) for val in self._all_residues]),
'n_neighbors_ignored': self._n_neighbors_ignored,
'atom_contacts': \
'atom_contacts':
self._serialize_contact_counter(self._atom_contacts),
'residue_contacts': \
'residue_contacts':
self._serialize_contact_counter(self._residue_contacts),
'use_atom_slice': self._use_atom_slice}
return dct
Expand All @@ -241,6 +248,7 @@ def from_dict(cls, dct):
'query': deserialize_set,
'haystack': deserialize_set,
'all_atoms': deserialize_set,
'all_residues': deserialize_set,
'atom_idx_to_residue_idx': deserialize_atom_to_residue_dct
}
for key in deserialization_helpers:
Expand Down Expand Up @@ -322,19 +330,22 @@ def _check_compatibility(self, other, err=AssertionError):
compatibility_attrs = ['cutoff', 'topology', 'query', 'haystack',
'n_neighbors_ignored']
failed_attr = {}
err_msg = ""
for attr in compatibility_attrs:
self_val = getattr(self, attr)
other_val = getattr(other, attr)
if self_val != other_val:
failed_attr.update({attr: (self_val, other_val)})
failed_attr[attr] = (self_val, other_val)
err_msg += " {attr}: {self} != {other}\n".format(
attr=attr, self=str(self_val), other=str(other_val)
)

msg = "Incompatible ContactObjects:\n"
for (attr, vals) in failed_attr.items():
msg += " {attr}: {self} != {other}\n".format(
attr=attr, self=str(vals[0]), other=str(vals[1])
)
if failed_attr:
msg += err_msg
if failed_attr and err is not None:
raise err(msg)
return True
else:
return failed_attr

def save_to_file(self, filename, mode="w"):
"""Save this object to the given file.
Expand Down Expand Up @@ -782,27 +793,41 @@ def residue_contacts(self):
class ContactDifference(ContactObject):
"""
Contact map comparison (atomic and residue).
This can compare single frames or entire trajectories (or even mix the
two!) While this can be directly instantiated by the user, the more
common way to make this object is by using the ``-`` operator, i.e.,
``diff = map_1 - map_2``.
"""
# Some class variables on how we handle mismatches, mainly for subclassing.
_allow_mismatched_atoms = False
_allow_mismatched_residues = False
_override_topology = True

def __init__(self, positive, negative):
self.positive = positive
self.negative = negative
positive._check_compatibility(negative)
super(ContactDifference, self).__init__(positive.topology,
positive.query,
positive.haystack,
positive.cutoff,
positive.n_neighbors_ignored)
fix_parameters = ParameterFixer(
allow_mismatched_atoms=self._allow_mismatched_atoms,
allow_mismatched_residues=self._allow_mismatched_residues,
override_topology=self._override_topology)

(topology, query,
haystack, cutoff,
n_neighbors_ignored) = fix_parameters.get_parameters(positive,
negative)
self._all_atoms_intersect = set(
positive._all_atoms).intersection(negative._all_atoms)
self._all_residues_intersect = set(
positive._all_residues).intersection(negative._all_residues)
super(ContactDifference, self).__init__(topology,
query,
haystack,
cutoff,
n_neighbors_ignored)

def to_dict(self):
"""Convert object to a dict.
Keys should be strings; values should be (JSON-) serializable.
See also
--------
from_dict
Expand All @@ -817,12 +842,10 @@ def to_dict(self):
@classmethod
def from_dict(cls, dct):
"""Create object from dict.
Parameters
----------
dct : dict
dict-formatted serialization (see to_dict for details)
See also
--------
to_dict
Expand Down Expand Up @@ -857,16 +880,99 @@ def from_contacts(self, *args, **kwargs): #pylint: disable=W0221

@property
def atom_contacts(self):
n_x = self.topology.n_atoms
n_y = self.topology.n_atoms
diff = collections.Counter(self.positive.atom_contacts.counter)
diff.subtract(self.negative.atom_contacts.counter)
return ContactCount(diff, self.topology.atom, n_x, n_y)
return self._get_filtered_sub(pos_count=self.positive.atom_contacts,
neg_count=self.negative.atom_contacts,
selection=self._all_atoms_intersect,
object_f=self.topology.atom,
n_x=self.topology.n_atoms,
n_y=self.topology.n_atoms)

@property
def residue_contacts(self):
n_x = self.topology.n_residues
n_y = self.topology.n_residues
diff = collections.Counter(self.positive.residue_contacts.counter)
diff.subtract(self.negative.residue_contacts.counter)
return ContactCount(diff, self.topology.residue, n_x, n_y)
return self._get_filtered_sub(pos_count=self.positive.residue_contacts,
neg_count=self.negative.residue_contacts,
selection=self._all_residues_intersect,
object_f=self.topology.residue,
n_x=self.topology.n_residues,
n_y=self.topology.n_residues)

def _get_filtered_sub(self, pos_count, neg_count, selection, *args,
**kwargs):
"""Get a filtered subtraction between two ContactCounts"""
filtered_pos = pos_count.filter(selection)
filtered_neg = neg_count.filter(selection)
diff = collections.Counter(filtered_pos.counter)
diff.subtract(filtered_neg.counter)
return ContactCount(diff, *args, **kwargs)


class AtomMismatchedContactDifference(ContactDifference):
"""
Contact map comparison (only residues).
"""
_allow_mismatched_atoms = True

def most_common_atoms_for_contact(self, *args, **kwargs):
self._missing_atom_contacts()

def most_common_atoms_for_residue(self, *args, **kwargs):
self._missing_atom_contacts()

@property
def atom_contacts(self):
self._missing_atom_contacts()

@property
def haystack_residues(self):
self._missing_atom_contacts()

@property
def query_residues(self):
self._missing_atom_contacts()

def _missing_atom_contacts(self):
raise RuntimeError("Different atom indices involved between the two"
" maps, so this does not make sense.")


class ResidueMismatchedContactDifference(ContactDifference):
"""
Contact map comparison (only atoms).
"""
_allow_mismatched_residues = True

@property
def residue_contacts(self):
self._missing_residue_contacts()

@property
def _residue_ignore_atom_idxs(self):
self._missing_residue_contacts()

def most_common_atoms_for_contact(self, *args, **kwargs):
self._missing_residue_contacts()

def most_common_atoms_for_residue(self, *args, **kwargs):
self._missing_residue_contacts()

@property
def haystack_residues(self):
self._missing_residue_contacts()

@property
def query_residues(self):
self._missing_residue_contacts()

def _missing_residue_contacts(self):
raise RuntimeError("Different residue indices involved between the two"
" maps, so this does not make sense.")


class OverrideTopologyContactDifference(ContactDifference):
"""
Contact map comparison with a user provided Topology.
"""

def __init__(self, positive, negative, topology):
self._override_topology = topology
super().__init__(positive, negative)

0 comments on commit 929fda7

Please sign in to comment.