Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DaskContactTrajectory #101

Merged
merged 12 commits into from
Apr 11, 2021
4 changes: 3 additions & 1 deletion .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ jobs:

steps:
- uses: actions/checkout@v2
with:
fetch-depth: 2
- uses: actions/setup-python@v2
- uses: conda-incubator/setup-miniconda@v2
with:
auto-update-python: true
auto-update-conda: true
python-version: ${{ matrix.CONDA_PY }}
- name: "Install"
env:
Expand Down
2 changes: 1 addition & 1 deletion contact_map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
ConcurrencePlotter, plot_concurrence
)

from .dask_runner import DaskContactFrequency
from .dask_runner import DaskContactFrequency, DaskContactTrajectory

from . import plot_utils
70 changes: 45 additions & 25 deletions contact_map/contact_trajectory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,48 @@
from collections import abc, Counter

from .contact_map import ContactFrequency, ContactObject
import json


# Split this out of the to prevent code duplication for DaskContactTrajectory
def _build_contacts(contact_object, trajectory):
"""Make a contact map for every frame in trajectory.

Parameters
----------
contact_object : `ContactObject`
The contact object that will be used to make the contact maps.
trajectory: `mdtraj.Trajectory`
The trajectory for which we will return a contactObject for each frame.

Returns
-------
out : list of tuples
list of (atom_contacts, residue_contacts) for each frame in trajectory.
"""

# atom_contacts, residue_contacts = self._empty_contacts()
atom_contacts = []
residue_contacts = []
residue_ignore_atom_idxs = contact_object._residue_ignore_atom_idxs
residue_query_atom_idxs = contact_object.indexer.residue_query_atom_idxs
used_trajectory = contact_object.indexer.slice_trajectory(trajectory)

# range(len(trajectory)) avoids recopying topology, as would occur
# in `for frame in trajectory`
for frame_num in range(len(trajectory)):
frame_contacts = contact_object._contact_map(used_trajectory,
frame_num,
residue_query_atom_idxs,
residue_ignore_atom_idxs)
frame_atom_contacts, frame_residue_contacts = frame_contacts
frame_atom_contacts = \
contact_object.indexer.convert_atom_contacts(frame_atom_contacts)
# TODO unify contact building with something like this?
# atom_contacts, residue_contact = self._update_contacts(...)
atom_contacts.append(frame_atom_contacts)
residue_contacts.append(frame_residue_contacts)
return zip(atom_contacts, residue_contacts)


class ContactTrajectory(ContactObject, abc.Sequence):
"""Track all the contacts over a trajectory, frame-by-frame.
Expand Down Expand Up @@ -44,7 +85,7 @@ def __init__(self, trajectory, query=None, haystack=None, cutoff=0.45,
n_frames=1,
indexer=self.indexer
)
for atom_contacts, residue_contacts in zip(*contacts)
for atom_contacts, residue_contacts in contacts
]

def __getitem__(self, num):
Expand Down Expand Up @@ -83,28 +124,7 @@ def from_contacts(cls, atom_contacts, residue_contacts, topology,
return cls.from_contact_maps(contact_maps)

def _build_contacts(self, trajectory):
# atom_contacts, residue_contacts = self._empty_contacts()
atom_contacts = []
residue_contacts = []

residue_ignore_atom_idxs = self._residue_ignore_atom_idxs
residue_query_atom_idxs = self.indexer.residue_query_atom_idxs
used_trajectory = self.indexer.slice_trajectory(trajectory)

# range(len(trajectory)) avoids recopying topology, as would occur
# in `for frame in trajectory`
for frame_num in range(len(trajectory)):
frame_contacts = self._contact_map(used_trajectory, frame_num,
residue_query_atom_idxs,
residue_ignore_atom_idxs)
frame_atom_contacts, frame_residue_contacts = frame_contacts
frame_atom_contacts = \
self.indexer.convert_atom_contacts(frame_atom_contacts)
# TODO unify contact building with something like this?
# atom_contacts, residue_contact = self._update_contacts(...)
atom_contacts.append(frame_atom_contacts)
residue_contacts.append(frame_residue_contacts)
return atom_contacts, residue_contacts
return _build_contacts(self, trajectory)

def contact_frequency(self):
"""Create a :class:`.ContactFrequency` from this contact trajectory
Expand Down Expand Up @@ -297,7 +317,7 @@ def _normal(self):

def __next__(self):
# if self.max + self.step < self.width:
# to_add, to_sub = self._startup()
# to_add, to_sub = self._startup()
if self.max + self.step < self.length:
to_add, to_sub = self._normal()
else:
Expand Down
117 changes: 111 additions & 6 deletions contact_map/dask_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from . import frequency_task
from .contact_map import ContactFrequency, ContactObject
from .contact_trajectory import ContactTrajectory
import mdtraj as md


Expand All @@ -27,19 +28,41 @@ def dask_run(trajectory, client, run_info):
:class:`.ContactFrequency` :
total contact frequency for the trajectory
"""
slices = frequency_task.default_slices(n_total=len(trajectory),
n_workers=len(client.ncores()))
subtrajs = dask_load_and_slice(trajectory, client, run_info)

subtrajs = client.map(frequency_task.load_trajectory_task, slices,
file_name=run_info['trajectory_file'],
**run_info['load_kwargs'])
maps = client.map(frequency_task.map_task, subtrajs,
parameters=run_info['parameters'])
freq = client.submit(frequency_task.reduce_all_results, maps)

return freq.result()


def dask_load_and_slice(trajectory, client, run_info):
"""
Run dask to load a trajectory and make a list of subtraj futures that
can be used for the other analysis

Parameters
----------
trajectory : mdtraj.Trajectory
client : dask.distributed.Client
run_info : dict
Keys are 'trajectory_file' (trajectory filename) and 'load_kwargs'
(additional kwargs passed to md.load)

Returns
-------
list of Futures :
A list of futures for loaded subtrajectory
"""
slices = frequency_task.default_slices(n_total=len(trajectory),
n_workers=len(client.ncores()))
subtrajs = client.map(frequency_task.load_trajectory_task, slices,
file_name=run_info['trajectory_file'],
**run_info['load_kwargs'])
return subtrajs


class DaskContactFrequency(ContactFrequency):
"""Dask-based parallelization of contact frequency.

Expand Down Expand Up @@ -84,7 +107,6 @@ def __init__(self, client, filename, query=None, haystack=None,
self.client = client
self.filename = filename
trajectory = md.load(filename, **kwargs)

self.kwargs = kwargs

super(DaskContactFrequency, self).__init__(
Expand All @@ -108,3 +130,86 @@ def run_info(self):
return {'parameters': self.parameters,
'trajectory_file': self.filename,
'load_kwargs': self.kwargs}


class DaskContactTrajectory(ContactTrajectory):
"""Dask-based parallelization of contact trajectory.

The contact trajectory tracks all contacts of a trajectory, frame-by-frame.
See :class:`.ContactTrajectory` for details. This implementation
parallelizes the contact map calculations using
``dask.distributed``, which must be installed separately to use this
object.

Notes
-----

The interface for this object closely mimics that of the
:class:`.ContactTrajectory` object, with the addition requiring the
``dask.distributed.Client`` as input. However, there is one important
difference. Whereas :class:`.ContactTrajectory` takes an
``mdtraj.Trajectory`` object as input, :class:`.DaskContactTrajectory`
takes a file name, plus any extra kwargs that MDTraj needs to load the
file.

Parameters
----------
client : dask.distributed.Client
Client object connected to the dask network.
filename : str
Name of the file where the trajectory is located. File must be
accessible by all workers in the dask network.
query : list of int
Indices of the atoms to be included as query. Default ``None``
means all atoms.
haystack : list of int
Indices of the atoms to be included as haystack. Default ``None``
means all atoms.
cutoff : float
Cutoff distance for contacts, in nanometers. Default 0.45.
n_neighbors_ignored : int
Number of neighboring residues (in the same chain) to ignore.
Default 2.
"""
def __init__(self, client, filename, query=None, haystack=None,
cutoff=0.45, n_neighbors_ignored=2, **kwargs):
self.client = client
self.filename = filename
self.kwargs = kwargs
self.trajectory = md.load(filename, **kwargs)
self.contact_object = ContactObject(
self.trajectory.topology,
query=query,
haystack=haystack,
cutoff=cutoff,
n_neighbors_ignored=n_neighbors_ignored,
)

super(DaskContactTrajectory, self).__init__(
self.trajectory, query, haystack, cutoff, n_neighbors_ignored,
)

def _build_contacts(self, trajectory):
subtrajs = dask_load_and_slice(self.trajectory, self.client,
self.run_info)
contact_lists = self.client.map(frequency_task.contacts_per_frame_task,
subtrajs,
contact_object=self.contact_object)
# Return a generator for this to work out
gen = ((atom_contacts, residue_contacts)
for contacts in contact_lists
for atom_contacts, residue_contacts in contacts.result())
return gen

@property
def parameters(self):
return {'query': self.query,
'haystack': self.haystack,
'cutoff': self.cutoff,
'n_neighbors_ignored': self.n_neighbors_ignored}

@property
def run_info(self):
return {'parameters': self.parameters,
'trajectory_file': self.filename,
'load_kwargs': self.kwargs}
19 changes: 19 additions & 0 deletions contact_map/frequency_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import mdtraj as md
from contact_map import ContactFrequency
from contact_map.contact_trajectory import _build_contacts

def block_slices(n_total, n_per_block):
"""Determine slices for splitting the input array.
Expand Down Expand Up @@ -86,6 +87,7 @@ def load_trajectory_task(subslice, file_name, **kwargs):
"""
return md.load(file_name, **kwargs)[subslice]


def map_task(subtrajectory, parameters):
"""Task to be mapped to all subtrajectories. Run ContactFrequency

Expand All @@ -103,6 +105,23 @@ def map_task(subtrajectory, parameters):
"""
return ContactFrequency(subtrajectory, **parameters)


def contacts_per_frame_task(trajectory, contact_object):
"""Task that will mimic ContactTrajectory._build_contacts, but with
a pre-initialized ContactObject instead of `self` to produce the contacts


Parameters
----------
trajectory : mdtraj.Trajectory
single trajectory segment to calculate contacts for every frame
contactobject : ContactObject
The instance that will replace self in _build_contacts

"""
return _build_contacts(contact_object, trajectory)


def reduce_all_results(contacts):
"""Combine multiple :class:`.ContactFrequency` objects into one

Expand Down
Loading