Skip to content

Commit

Permalink
Merge pull request #90 from dwhswenson/to_networkx
Browse files Browse the repository at this point in the history
ContactCount.to_networkx
  • Loading branch information
dwhswenson committed Jun 15, 2021
2 parents ae5deca + 73cca8d commit 914dfe1
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 12 deletions.
42 changes: 42 additions & 0 deletions contact_map/contact_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
else:
HAS_MATPLOTLIB = True

try:
import networkx as nx
except ImportError:
HAS_NETWORKX = False
else:
HAS_NETWORKX = True

# pandas 0.25 not available on py27; can drop this when we drop py27
_PD_VERSION = tuple(int(x) for x in pd.__version__.split('.')[:2])

Expand Down Expand Up @@ -164,6 +171,41 @@ def df(self):
df = df.astype(pd.SparseDtype("float", np.nan))
return df

def to_networkx(self, weighted=True, as_index=False, graph=None):
"""Graph representation of contacts (requires networkx)
Parameters
----------
weighted : bool
whether to use the frequencies as edge weights in the graph,
default True
as_index : bool
if True, the nodes in the graph are integer indices; if False
(default), the nodes are mdtraj.topology objects (Atom/Residue)
graph : networkx.Graph or None
if provided, edges are added to an existing graph
Returns
-------
networkx.Graph :
graph representation of the contact matrix
"""
if not HAS_NETWORKX: # -no-cov-
raise RuntimeError("Error importing networkx")

graph = nx.Graph() if graph is None else graph

for pair, value in self.counter.items():
if not as_index:
pair = map(self._object_f, pair)

attr_dict = {'weight': value} if weighted else {}

graph.add_edge(*pair, **attr_dict)

return graph


def _check_number_of_pixels(self, figure):
"""
This checks to see if the number of pixels in the figure is high enough
Expand Down
39 changes: 39 additions & 0 deletions contact_map/tests/test_contact_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,45 @@ def test_pixel_warning(self):
self.atom_contacts.plot(figsize=(4, 4), dpi=2)
assert len(record) == 1

@pytest.mark.skipif(not HAS_NETWORKX, reason="Missing networkx")
@pytest.mark.parametrize('weighted', [True, False])
def test_to_networkx(self, weighted):
as_index = self.residue_contacts.to_networkx(
as_index=True, weighted=weighted
)
as_res = self.residue_contacts.to_networkx(weighted=weighted)

mappings = [lambda x: x, self.map.topology.residue]

for graph, mapping in zip([as_index, as_res], mappings):
assert len(graph.edges) == 4
assert len(graph.nodes) == 4
if weighted:
assert graph[mapping(0)][mapping(4)]['weight'] == 0.2
assert graph[mapping(4)][mapping(0)]['weight'] == 0.2
else:
edge = graph[mapping(0)][mapping(4)]
with pytest.raises(KeyError):
edge['weight']
with pytest.raises(KeyError):
graph[mapping(1)][mapping(0)]

@pytest.mark.skipif(not HAS_NETWORKX, reason="Missing networkx")
def test_to_networkx_existing(self):
import networkx as nx
graph = nx.Graph()
graph.add_edge(5, 6, weight=1.0)
assert len(graph.nodes) == 2
assert len(graph.edges) == 1
assert graph[5][6]['weight'] == 1.0
graph = self.residue_contacts.to_networkx(as_index=True,
graph=graph)
assert len(graph.nodes) == 6
assert len(graph.edges) == 5
assert graph[5][6]['weight'] == 1.0
assert graph[0][4]['weight'] == 0.2
assert graph[4][0]['weight'] == 0.2

def test_initialization(self):
assert self.atom_contacts._object_f == self.topology.atom
assert self.atom_contacts.n_x == self.map.query_range
Expand Down
56 changes: 44 additions & 12 deletions examples/exporting_data.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions optional_installs.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
matplotlib
dask
distributed
networkx

0 comments on commit 914dfe1

Please sign in to comment.