Skip to content

Commit

Permalink
Merge pull request #94 from sroet/autozoom_plot
Browse files Browse the repository at this point in the history
[non-public API-break] Autozoom plots
  • Loading branch information
dwhswenson committed May 4, 2021
2 parents c9d1451 + 646f10b commit aa0c46e
Show file tree
Hide file tree
Showing 13 changed files with 366 additions and 203 deletions.
51 changes: 37 additions & 14 deletions contact_map/contact_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
import warnings
from .plot_utils import ranged_colorbar
from .plot_utils import ranged_colorbar, make_x_y_ranges

# matplotlib is technically optional, but required for plotting
try:
Expand Down Expand Up @@ -57,6 +57,14 @@ def _patch_from_spmatrix(cls, data): # -no-cov-
pd.core.arrays.SparseArray.from_spmatrix = classmethod(_patch_from_spmatrix)
# TODO: this is the end of what to remove when pandas is fixed


def _get_total_counter_range(counter):
numbers = [i for key in counter.keys() for i in key]
if len(numbers) == 0:
return (0, 0)
return (min(numbers), max(numbers)+1)


class ContactCount(object):
"""Return object when dealing with contacts (residue or atom).
Expand All @@ -83,16 +91,27 @@ class ContactCount(object):
method to obtain the object associated with the number used in
``counter``; typically :meth:`mdtraj.Topology.residue` or
:meth:`mdtraj.Topology.atom`.
n_x : int
number of objects in the x direction (used in plotting)
n_y : int
number of objects in the y direction (used in plotting)
n_x : int, tuple(start, end), optional
range of objects in the x direction (used in plotting)
Default tries to plot the least amount of symetric points.
n_y : int, tuple(start, end), optional
range of objects in the y direction (used in plotting)
Default tries to show the least amount of symetric points.
max_size : int, optional
maximum size of the count
(used to determine the shape of output matrices and dataframes)
"""
def __init__(self, counter, object_f, n_x, n_y):
def __init__(self, counter, object_f, n_x=None, n_y=None, max_size=None):
self._counter = counter
self._object_f = object_f
self.n_x = n_x
self.n_y = n_y
self.total_range = _get_total_counter_range(counter)
self.n_x, self.n_y = make_x_y_ranges(n_x, n_y, counter)
if max_size is None:
self.max_size = max([self.total_range[-1],
self.n_x.max,
self.n_y.max])
else:
self.max_size = max_size

@property
def counter(self):
Expand All @@ -111,7 +130,8 @@ def sparse_matrix(self):
Rows/columns correspond to indices and the values correspond to
the count
"""
mtx = scipy.sparse.dok_matrix((self.n_x, self.n_y))
max_size = self.max_size
mtx = scipy.sparse.dok_matrix((max_size, max_size))
for (k, v) in self._counter.items():
key = list(k)
mtx[key[0], key[1]] = v
Expand All @@ -128,8 +148,8 @@ def df(self):
the count
"""
mtx = self.sparse_matrix
index = list(range(self.n_x))
columns = list(range(self.n_y))
index = list(range(self.max_size))
columns = list(range(self.max_size))

if _PD_VERSION < (0, 25): # py27 only -no-cov-
mtx = mtx.tocoo()
Expand Down Expand Up @@ -164,15 +184,18 @@ def _check_number_of_pixels(self, figure):
ypixels = dpi*figheight

# Check if every value has a pixel
if xpixels/self.n_x < 1 or ypixels/self.n_y < 1:
if (xpixels/self.n_x.range_length < 1 or
ypixels/self.n_y.range_length < 1):
msg = ("The number of pixels in the figure is insufficient to show"
" all the contacts.\n Please save this as a vector image "
"(such as a PDF) to view the correct result.\n Another "
"option is to increase the 'dpi' (currently: "+str(dpi)+"),"
" or the 'figsize' (currently: " + str((figwidth,
figheight)) +
").\n Recommended minimum amount of pixels = "
+ str((self.n_x, self.n_y))+" (width, height).")
+ str((self.n_x.range_length,
self.n_y.range_length))
+ " (width, height).")
warnings.warn(msg, RuntimeWarning)

def plot(self, cmap='seismic', vmin=-1.0, vmax=1.0, with_colorbar=True,
Expand Down Expand Up @@ -230,7 +253,7 @@ def plot_axes(self, ax, cmap='seismic', vmin=-1.0, vmax=1.0,

norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
cmap_f = plt.get_cmap(cmap)
ax.axis([0, self.n_x, 0, self.n_y])
ax.axis([self.n_x.min, self.n_x.max, self.n_y.min, self.n_y.max])
ax.set_facecolor(cmap_f(norm(0.0)))

min_val = 0.0
Expand Down
47 changes: 29 additions & 18 deletions contact_map/contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,9 @@ def _residue_idx_for_atom(topology, atom_list):
return set([topology.atom(a).residue.index for a in atom_list])


def _range_from_object_list(object_list):
"""
Objects must have .index attribute (e.g., MDTraj Residue/Atom)
"""
idxs = [obj.index for obj in object_list]
return (min(idxs), max(idxs) + 1)
def _range_from_iterable(iterable):
sort = sorted(iterable)
return (sort[0], sort[-1]+1)


class ContactsDict(object):
Expand Down Expand Up @@ -134,6 +131,8 @@ def __init__(self, topology, query, haystack, cutoff, n_neighbors_ignored):
self._query = set(query)
self._haystack = set(haystack)

self._query_res_idx = set(_residue_idx_for_atom(topology, query))
self._haystack_res_idx = set(_residue_idx_for_atom(topology, haystack))
# Make tuple for efficient lookupt
all_atoms_set = set(query).union(set(haystack))
self._all_atoms = tuple(sorted(list(all_atoms_set)))
Expand Down Expand Up @@ -214,6 +213,10 @@ def to_dict(self):
'cutoff': self._cutoff,
'query': list([int(val) for val in self._query]),
'haystack': list([int(val) for val in self._haystack]),
'query_res_idx': list([int(val) for val
in self._query_res_idx]),
'haystack_res_idx': list([int(val) for val in
self._haystack_res_idx]),
'all_atoms': tuple(
[int(val) for val in self._all_atoms]),
'all_residues': tuple(
Expand Down Expand Up @@ -247,6 +250,8 @@ def from_dict(cls, dct):
'residue_contacts': cls._deserialize_contact_counter,
'query': deserialize_set,
'haystack': deserialize_set,
'query_res_idx': deserialize_set,
'haystack_res_idx': deserialize_set,
'all_atoms': deserialize_set,
'all_residues': deserialize_set,
'atom_idx_to_residue_idx': deserialize_atom_to_residue_dct
Expand Down Expand Up @@ -464,15 +469,25 @@ def query_residues(self):
"""list : residues for atoms in the query"""
return _residue_for_atom(self.topology, self.query)

@property
def query_range(self):
"""return an tuple with the (min, max+1) of query"""
return _range_from_iterable(self.query)

@property
def haystack_range(self):
"""return an tuple with the (min, max+1) of haystack"""
return _range_from_iterable(self.haystack)

@property
def haystack_residue_range(self):
"""(int, int): min and (max + 1) of haystack residue indices"""
return _range_from_object_list(self.haystack_residues)
return _range_from_iterable(self._haystack_res_idx)

@property
def query_residue_range(self):
"""(int, int): min and (max + 1) of query residue indices"""
return _range_from_object_list(self.query_residues)
return _range_from_iterable(self._query_res_idx)

def most_common_atoms_for_residue(self, residue):
"""
Expand Down Expand Up @@ -768,22 +783,20 @@ def subtract_contact_frequency(self, other):
@property
def atom_contacts(self):
"""Atoms pairs mapped to fraction of trajectory with that contact"""
n_x = self.topology.n_atoms
n_y = self.topology.n_atoms
return ContactCount(collections.Counter({
item[0]: float(item[1])/self.n_frames
for item in self._atom_contacts.items()
}), self.topology.atom, n_x, n_y)
}), self.topology.atom, self.query_range, self.haystack_range,
self.topology.n_atoms)

@property
def residue_contacts(self):
"""Residue pairs mapped to fraction of trajectory with that contact"""
n_x = self.topology.n_residues
n_y = self.topology.n_residues
return ContactCount(collections.Counter({
item[0]: float(item[1])/self.n_frames
for item in self._residue_contacts.items()
}), self.topology.residue, n_x, n_y)
}), self.topology.residue, self.query_residue_range,
self.haystack_residue_range, self.topology.n_residues)


class ContactDifference(ContactObject):
Expand Down Expand Up @@ -880,17 +893,15 @@ def atom_contacts(self):
neg_count=self.negative.atom_contacts,
selection=self._all_atoms_intersect,
object_f=self.topology.atom,
n_x=self.topology.n_atoms,
n_y=self.topology.n_atoms)
max_size=self.topology.n_atoms)

@property
def residue_contacts(self):
return self._get_filtered_sub(pos_count=self.positive.residue_contacts,
neg_count=self.negative.residue_contacts,
selection=self._all_residues_intersect,
object_f=self.topology.residue,
n_x=self.topology.n_residues,
n_y=self.topology.n_residues)
max_size=self.topology.n_residues)

def _get_filtered_sub(self, pos_count, neg_count, selection, *args,
**kwargs):
Expand Down
75 changes: 75 additions & 0 deletions contact_map/plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import collections

try: # try loop for testing
import matplotlib
Expand Down Expand Up @@ -51,3 +52,77 @@ def ranged_colorbar(cmap, norm, cbmin, cbmax, ax=None):
sm._A = []
cb = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
return cb


def _int_or_range_to_tuple(posible_int):
if isinstance(posible_int, collections.abc.Iterable):
return (posible_int[0], posible_int[1])
else:
return (0, posible_int)


def _get_low_high_counter_range(counter):
"""Give the (min, max + 1) for both the low and high keys in counter"""
keys = [tuple(sorted(list(i))) for i in counter.keys()]
if len(keys) == 0:
return (0, 0), (0, 0)
lows, highs = zip(*keys)
return (min(lows), max(lows)+1), (min(highs), max(highs)+1)


def _get_sorted_counter_range(counter):
"""Return smallest range, longest range for the low and high counter"""
low, high = _get_low_high_counter_range(counter)
if low[1]-low[0] > high[-1]-high[0]:
return high, low
else:
return low, high


def _sanitize_n_x_n_y(n_x, n_y, counter):
if n_x is None and n_y is None:
n_x, n_y = _get_sorted_counter_range(counter)
elif n_x is None or n_y is None:
raise ValueError("Either both n_x and n_y need to be defined or "
"neither.")
if isinstance(n_x, _ContactPlotRange):
n_x = n_x.n
if isinstance(n_y, _ContactPlotRange):
n_y = n_y.n
return n_x, n_y


def make_x_y_ranges(n_x, n_y, counter):
"""Return ContactPlotRange for both x and y"""
n_x, n_y = _sanitize_n_x_n_y(n_x, n_y, counter)
n_x = _ContactPlotRange(n_x)
n_y = _ContactPlotRange(n_y)
return n_x, n_y


class _ContactPlotRange(object):
"""Object that deals with functions that are needed for plot ranges
Parameters
----------
n : int, tuple(start, end)
range of objects in the given direction (used in plotting)
"""
def __init__(self, n):
self.n = n
self.min, self.max = _int_or_range_to_tuple(n)

@property
def range_length(self):
return self.max - self.min

def __eq__(self, other):
if isinstance(other, (int, tuple)):
return self.n == other
elif isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
else:
return False

def __ne__(self, other):
return not self.__eq__(other)
41 changes: 37 additions & 4 deletions contact_map/tests/test_contact_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ def test_pixel_warning(self):

def test_initialization(self):
assert self.atom_contacts._object_f == self.topology.atom
assert self.atom_contacts.n_x == self.topology.n_atoms
assert self.atom_contacts.n_y == self.topology.n_atoms
assert self.atom_contacts.n_x == self.map.query_range
assert self.atom_contacts.n_y == self.map.haystack_range
assert self.atom_contacts.max_size == self.topology.n_atoms
assert self.residue_contacts._object_f == self.topology.residue
assert self.residue_contacts.n_x == self.topology.n_residues
assert self.residue_contacts.n_y == self.topology.n_residues
assert self.residue_contacts.n_x == self.map.query_residue_range
assert self.residue_contacts.n_y == self.map.haystack_residue_range
assert self.residue_contacts.max_size == self.topology.n_residues

def test_sparse_matrix(self):
assert_array_equal(self.map.atom_contacts.sparse_matrix.todense(),
Expand Down Expand Up @@ -171,3 +173,34 @@ def test_most_common_idx(self, obj_type):
expected_count = [(ll[0], float(ll[1]) / 5.0)
for ll in source_expected.items()]
assert set(contacts.most_common_idx()) == set(expected_count)

def test_n_x_smaller_than_n_y_default(self):
# Make a map that has a bigger range of for low numbers
ac0 = ContactFrequency(traj, cutoff=0.075,
n_neighbors_ignored=0,
query=[0, 4],
haystack=[7, 8, 9]).atom_contacts
# Also make a map that has a bigger range of high numbers
ac1 = ContactFrequency(traj, cutoff=0.075,
n_neighbors_ignored=0,
query=[0, 1],
haystack=[5, 8, 9]).atom_contacts
default0 = ContactCount(ac0._counter, ac0._object_f)
default1 = ContactCount(ac1._counter, ac1._object_f)
assert default0.n_x == (7, 9 + 1) # n_x should be shorter: here 3
assert default0.n_y == (0, 4 + 1) # n_y should be longer: here 5

assert default1.n_x == (0, 1 + 1) # n_x should be shorter: here 2
assert default1.n_y == (5, 9 + 1) # n_y should be longer: here 5

@pytest.mark.parametrize("keyword", ["n_x", "n_y"])
def test_raise_on_only_n_x_or_ny(self, keyword):
ac = self.map.atom_contacts
kwargs = {keyword: "test"}
with pytest.raises(ValueError) as e:
ContactCount(counter=ac._counter, object_f=ac._object_f, **kwargs)
assert keyword in str(e.value)

def test_empty_counter(self):
# Just a smoke test, this should not error out
ContactCount(dict(), None)
1 change: 1 addition & 0 deletions contact_map/tests/test_contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,7 @@ def test_broken_atoms_and_residues_missing(self):

# Make sure this now works
diff = OverrideTopologyContactDifference(ttraj, frame, ttraj.topology)

assert diff.residue_contacts is not None
assert diff.atom_contacts is not None
assert diff.topology == ttraj.topology
Expand Down

0 comments on commit aa0c46e

Please sign in to comment.