Skip to content

Commit

Permalink
Merge pull request #101 from sroet/daskcontacttrajectory
Browse files Browse the repository at this point in the history
DaskContactTrajectory
  • Loading branch information
dwhswenson committed Apr 11, 2021
2 parents fbcfdfd + 9d56908 commit eb0c208
Show file tree
Hide file tree
Showing 9 changed files with 1,210 additions and 900 deletions.
2 changes: 1 addition & 1 deletion contact_map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
ConcurrencePlotter, plot_concurrence
)

from .dask_runner import DaskContactFrequency
from .dask_runner import DaskContactFrequency, DaskContactTrajectory

from . import plot_utils
70 changes: 45 additions & 25 deletions contact_map/contact_trajectory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,48 @@
from collections import abc, Counter

from .contact_map import ContactFrequency, ContactObject
import json


# Split this out of the to prevent code duplication for DaskContactTrajectory
def _build_contacts(contact_object, trajectory):
"""Make a contact map for every frame in trajectory.
Parameters
----------
contact_object : `ContactObject`
The contact object that will be used to make the contact maps.
trajectory: `mdtraj.Trajectory`
The trajectory for which we will return a contactObject for each frame.
Returns
-------
out : list of tuples
list of (atom_contacts, residue_contacts) for each frame in trajectory.
"""

# atom_contacts, residue_contacts = self._empty_contacts()
atom_contacts = []
residue_contacts = []
residue_ignore_atom_idxs = contact_object._residue_ignore_atom_idxs
residue_query_atom_idxs = contact_object.indexer.residue_query_atom_idxs
used_trajectory = contact_object.indexer.slice_trajectory(trajectory)

# range(len(trajectory)) avoids recopying topology, as would occur
# in `for frame in trajectory`
for frame_num in range(len(trajectory)):
frame_contacts = contact_object._contact_map(used_trajectory,
frame_num,
residue_query_atom_idxs,
residue_ignore_atom_idxs)
frame_atom_contacts, frame_residue_contacts = frame_contacts
frame_atom_contacts = \
contact_object.indexer.convert_atom_contacts(frame_atom_contacts)
# TODO unify contact building with something like this?
# atom_contacts, residue_contact = self._update_contacts(...)
atom_contacts.append(frame_atom_contacts)
residue_contacts.append(frame_residue_contacts)
return zip(atom_contacts, residue_contacts)


class ContactTrajectory(ContactObject, abc.Sequence):
"""Track all the contacts over a trajectory, frame-by-frame.
Expand Down Expand Up @@ -44,7 +85,7 @@ def __init__(self, trajectory, query=None, haystack=None, cutoff=0.45,
n_frames=1,
indexer=self.indexer
)
for atom_contacts, residue_contacts in zip(*contacts)
for atom_contacts, residue_contacts in contacts
]

def __getitem__(self, num):
Expand Down Expand Up @@ -83,28 +124,7 @@ def from_contacts(cls, atom_contacts, residue_contacts, topology,
return cls.from_contact_maps(contact_maps)

def _build_contacts(self, trajectory):
# atom_contacts, residue_contacts = self._empty_contacts()
atom_contacts = []
residue_contacts = []

residue_ignore_atom_idxs = self._residue_ignore_atom_idxs
residue_query_atom_idxs = self.indexer.residue_query_atom_idxs
used_trajectory = self.indexer.slice_trajectory(trajectory)

# range(len(trajectory)) avoids recopying topology, as would occur
# in `for frame in trajectory`
for frame_num in range(len(trajectory)):
frame_contacts = self._contact_map(used_trajectory, frame_num,
residue_query_atom_idxs,
residue_ignore_atom_idxs)
frame_atom_contacts, frame_residue_contacts = frame_contacts
frame_atom_contacts = \
self.indexer.convert_atom_contacts(frame_atom_contacts)
# TODO unify contact building with something like this?
# atom_contacts, residue_contact = self._update_contacts(...)
atom_contacts.append(frame_atom_contacts)
residue_contacts.append(frame_residue_contacts)
return atom_contacts, residue_contacts
return _build_contacts(self, trajectory)

def contact_frequency(self):
"""Create a :class:`.ContactFrequency` from this contact trajectory
Expand Down Expand Up @@ -297,7 +317,7 @@ def _normal(self):

def __next__(self):
# if self.max + self.step < self.width:
# to_add, to_sub = self._startup()
# to_add, to_sub = self._startup()
if self.max + self.step < self.length:
to_add, to_sub = self._normal()
else:
Expand Down
117 changes: 111 additions & 6 deletions contact_map/dask_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from . import frequency_task
from .contact_map import ContactFrequency, ContactObject
from .contact_trajectory import ContactTrajectory
import mdtraj as md


Expand All @@ -27,19 +28,41 @@ def dask_run(trajectory, client, run_info):
:class:`.ContactFrequency` :
total contact frequency for the trajectory
"""
slices = frequency_task.default_slices(n_total=len(trajectory),
n_workers=len(client.ncores()))
subtrajs = dask_load_and_slice(trajectory, client, run_info)

subtrajs = client.map(frequency_task.load_trajectory_task, slices,
file_name=run_info['trajectory_file'],
**run_info['load_kwargs'])
maps = client.map(frequency_task.map_task, subtrajs,
parameters=run_info['parameters'])
freq = client.submit(frequency_task.reduce_all_results, maps)

return freq.result()


def dask_load_and_slice(trajectory, client, run_info):
"""
Run dask to load a trajectory and make a list of subtraj futures that
can be used for the other analysis
Parameters
----------
trajectory : mdtraj.Trajectory
client : dask.distributed.Client
run_info : dict
Keys are 'trajectory_file' (trajectory filename) and 'load_kwargs'
(additional kwargs passed to md.load)
Returns
-------
list of Futures :
A list of futures for loaded subtrajectory
"""
slices = frequency_task.default_slices(n_total=len(trajectory),
n_workers=len(client.ncores()))
subtrajs = client.map(frequency_task.load_trajectory_task, slices,
file_name=run_info['trajectory_file'],
**run_info['load_kwargs'])
return subtrajs


class DaskContactFrequency(ContactFrequency):
"""Dask-based parallelization of contact frequency.
Expand Down Expand Up @@ -84,7 +107,6 @@ def __init__(self, client, filename, query=None, haystack=None,
self.client = client
self.filename = filename
trajectory = md.load(filename, **kwargs)

self.kwargs = kwargs

super(DaskContactFrequency, self).__init__(
Expand All @@ -108,3 +130,86 @@ def run_info(self):
return {'parameters': self.parameters,
'trajectory_file': self.filename,
'load_kwargs': self.kwargs}


class DaskContactTrajectory(ContactTrajectory):
"""Dask-based parallelization of contact trajectory.
The contact trajectory tracks all contacts of a trajectory, frame-by-frame.
See :class:`.ContactTrajectory` for details. This implementation
parallelizes the contact map calculations using
``dask.distributed``, which must be installed separately to use this
object.
Notes
-----
The interface for this object closely mimics that of the
:class:`.ContactTrajectory` object, with the addition requiring the
``dask.distributed.Client`` as input. However, there is one important
difference. Whereas :class:`.ContactTrajectory` takes an
``mdtraj.Trajectory`` object as input, :class:`.DaskContactTrajectory`
takes a file name, plus any extra kwargs that MDTraj needs to load the
file.
Parameters
----------
client : dask.distributed.Client
Client object connected to the dask network.
filename : str
Name of the file where the trajectory is located. File must be
accessible by all workers in the dask network.
query : list of int
Indices of the atoms to be included as query. Default ``None``
means all atoms.
haystack : list of int
Indices of the atoms to be included as haystack. Default ``None``
means all atoms.
cutoff : float
Cutoff distance for contacts, in nanometers. Default 0.45.
n_neighbors_ignored : int
Number of neighboring residues (in the same chain) to ignore.
Default 2.
"""
def __init__(self, client, filename, query=None, haystack=None,
cutoff=0.45, n_neighbors_ignored=2, **kwargs):
self.client = client
self.filename = filename
self.kwargs = kwargs
self.trajectory = md.load(filename, **kwargs)
self.contact_object = ContactObject(
self.trajectory.topology,
query=query,
haystack=haystack,
cutoff=cutoff,
n_neighbors_ignored=n_neighbors_ignored,
)

super(DaskContactTrajectory, self).__init__(
self.trajectory, query, haystack, cutoff, n_neighbors_ignored,
)

def _build_contacts(self, trajectory):
subtrajs = dask_load_and_slice(self.trajectory, self.client,
self.run_info)
contact_lists = self.client.map(frequency_task.contacts_per_frame_task,
subtrajs,
contact_object=self.contact_object)
# Return a generator for this to work out
gen = ((atom_contacts, residue_contacts)
for contacts in contact_lists
for atom_contacts, residue_contacts in contacts.result())
return gen

@property
def parameters(self):
return {'query': self.query,
'haystack': self.haystack,
'cutoff': self.cutoff,
'n_neighbors_ignored': self.n_neighbors_ignored}

@property
def run_info(self):
return {'parameters': self.parameters,
'trajectory_file': self.filename,
'load_kwargs': self.kwargs}
19 changes: 19 additions & 0 deletions contact_map/frequency_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import mdtraj as md
from contact_map import ContactFrequency
from contact_map.contact_trajectory import _build_contacts

def block_slices(n_total, n_per_block):
"""Determine slices for splitting the input array.
Expand Down Expand Up @@ -86,6 +87,7 @@ def load_trajectory_task(subslice, file_name, **kwargs):
"""
return md.load(file_name, **kwargs)[subslice]


def map_task(subtrajectory, parameters):
"""Task to be mapped to all subtrajectories. Run ContactFrequency
Expand All @@ -103,6 +105,23 @@ def map_task(subtrajectory, parameters):
"""
return ContactFrequency(subtrajectory, **parameters)


def contacts_per_frame_task(trajectory, contact_object):
"""Task that will mimic ContactTrajectory._build_contacts, but with
a pre-initialized ContactObject instead of `self` to produce the contacts
Parameters
----------
trajectory : mdtraj.Trajectory
single trajectory segment to calculate contacts for every frame
contactobject : ContactObject
The instance that will replace self in _build_contacts
"""
return _build_contacts(contact_object, trajectory)


def reduce_all_results(contacts):
"""Combine multiple :class:`.ContactFrequency` objects into one
Expand Down
71 changes: 46 additions & 25 deletions contact_map/tests/test_dask_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from .utils import *
from contact_map.dask_runner import *
from contact_map import ContactFrequency
from contact_map import ContactFrequency, ContactTrajectory
from collections.abc import Iterable
import mdtraj


def dask_setup_test_cluster(distributed, n_workers=4, n_attempts=3):
"""Set up a test cluster using dask.distributed. Try up to n_attempts
Expand All @@ -25,45 +28,63 @@ def dask_setup_test_cluster(distributed, n_workers=4, n_attempts=3):
pytest.skip("Failed to set up distributed LocalCluster")


class TestDaskContactFrequency(object):
def test_dask_integration(self):
# this is an integration test to check that dask works
class TestDaskRunners(object):
def setup(self):
dask = pytest.importorskip('dask') # pylint: disable=W0612
distributed = pytest.importorskip('dask.distributed')
self.distributed = distributed
# Explicitly set only 4 workers on Travis instead of 31
# Fix copied from https://github.com/spencerahill/aospy/pull/220/files
cluster = dask_setup_test_cluster(distributed, n_workers=4)
client = distributed.Client(cluster)
filename = find_testfile("trajectory.pdb")
self.cluster = dask_setup_test_cluster(distributed, n_workers=4)
self.client = distributed.Client(self.cluster)
self.filename = find_testfile("trajectory.pdb")

dask_freq = DaskContactFrequency(client, filename, cutoff=0.075,
n_neighbors_ignored=0)
client.close()
assert dask_freq.n_frames == 5
def teardown(self):
self.client.shutdown()

def test_dask_atom_slice(self):
# This is an integration test to check that dask works with atom_slice
dask = pytest.importorskip('dask') # pylint: disable=W0612
distributed = pytest.importorskip('dask.distributed')
# Explicitly set only 4 workers on Travis instead of 31
# Fix copied from https://github.com/spencerahill/aospy/pull/220/files
cluster = dask_setup_test_cluster(distributed, n_workers=4)
client = distributed.Client(cluster)
filename = find_testfile("trajectory.pdb")
@pytest.mark.parametrize("dask_cls", [DaskContactFrequency,
DaskContactTrajectory])
def test_dask_integration(self, dask_cls):
dask_freq = dask_cls(self.client, self.filename, cutoff=0.075,
n_neighbors_ignored=0)
if isinstance(dask_freq, ContactFrequency):
assert dask_freq.n_frames == 5
elif isinstance(dask_freq, ContactTrajectory):
assert len(dask_freq) == 5

dask_freq0 = DaskContactFrequency(client, filename, query=[3, 4],
def test_dask_atom_slice(self):
dask_freq0 = DaskContactFrequency(self.client, self.filename,
query=[3, 4],
haystack=[6, 7], cutoff=0.075,
n_neighbors_ignored=0)
client.close()
self.client.close()
assert dask_freq0.n_frames == 5
client = distributed.Client(cluster)
self.client = self.distributed.Client(self.cluster)
# Set the slicing of contact frequency (used in the frqeuency task)
# to False
ContactFrequency._class_use_atom_slice = False
dask_freq1 = DaskContactFrequency(client, filename, query=[3, 4],
dask_freq1 = DaskContactFrequency(self.client,
self.filename, query=[3, 4],
haystack=[6, 7], cutoff=0.075,
n_neighbors_ignored=0)
client.close()
assert dask_freq0._use_atom_slice is True
assert dask_freq1._use_atom_slice is False
assert dask_freq0 == dask_freq1

@pytest.mark.parametrize("dask_cls, norm_cls",[
(DaskContactFrequency, ContactFrequency),
(DaskContactTrajectory, ContactTrajectory)])
def test_answer_equal(self, dask_cls, norm_cls):
trj = mdtraj.load(self.filename)
dask_result = dask_cls(self.client, self.filename)
norm_result = norm_cls(trj)
if isinstance(dask_result, Iterable):
for i, j in zip(dask_result, norm_result):
assert i.atom_contacts._counter == j.atom_contacts._counter
assert (i.residue_contacts._counter ==
j.residue_contacts._counter)
else:
assert (dask_result.atom_contacts._counter ==
norm_result.atom_contacts._counter)
assert (dask_result.residue_contacts._counter ==
norm_result.residue_contacts._counter)
Loading

0 comments on commit eb0c208

Please sign in to comment.