-
Notifications
You must be signed in to change notification settings - Fork 1
/
graph_visualize.py
131 lines (112 loc) · 5.34 KB
/
graph_visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
r"""Matplotlib utils for graph plots.
"""
import os
import matplotlib.pyplot as plt
import networkx as nx
from networkx import Graph
import torch
from torch_geometric.data import Data
from torch_geometric.utils import remove_self_loops, to_networkx, subgraph, k_hop_subgraph
# plt.style.use('seaborn')
def plot_graph(nx_Gs, color_attr=None, colors=None, font_color='white', font_size=12, node_size=300,
               arrows=True, vmin=None, vmax=None, pos=None, line_width=1.0,
               enable_label=True, label_attr=None, enable_colorbar=True, save_fig_path=None, **kwargs):
    r"""
    Plot a graph with Matplotlib.

    Args:
        nx_Gs (Graph): input graph (:obj:`networkx.Graph`).
        color_attr (str or None): node attribute whose values are used as node colors.
            Ignored when ``colors`` is given.
        colors (str or list or None): explicit node color(s). Takes precedence over
            ``color_attr``; if both are None, every node is drawn in ``'red'``.
        font_color (str): font color of node labels.
        font_size (int): font size of node labels.
        node_size (int): size of nodes.
        arrows (bool or None): arrowheads style for drawing. If None, directed graphs draw
            arrowheads with FancyArrowPatch, while undirected graphs draw edges via
            LineCollection for speed. If True, draw arrowheads with FancyArrowPatches
            (bendable and stylish). If False, draw edges using LineCollection (linear and fast).
        vmin (float or None): minimum for node colormap scaling.
        vmax (float or None): maximum for node colormap scaling.
        pos (torch.Tensor or array-like or None): node positions, one ``(x, y)`` row per node,
            row ``i`` belonging to node ``i``. The y axis is flipped before drawing; the
            caller's object is not modified. If None, a Kamada-Kawai layout is computed.
        line_width (float): width of edges.
        enable_label (bool): whether to draw node labels.
        label_attr (str or None): node attribute used as label text. Nodes lacking this
            attribute (all nodes, when it is None) get no label.
        enable_colorbar (bool): whether to draw a colorbar for the node colors.
        save_fig_path (str or None): path to save the figure to; missing parent directories
            are created. If falsy, the figure is shown instead.
        **kwargs: ``edge_color`` (torch.Tensor or any networkx edge-color spec) sets edge
            colors; tensors are moved to CPU. Numeric edge colors are mapped through the
            ``'hot'`` colormap over ``[0, 1]``. When omitted, edges use networkx's default
            ``'k'``. Other keys are ignored.

    Returns:
        None
    """
    # Explicit colors win over an attribute lookup; fall back to a single color.
    if colors is None:
        if color_attr is not None:
            colors = list(nx.get_node_attributes(nx_Gs, color_attr).values())
        else:
            colors = 'red'

    edge_color = kwargs.get('edge_color')
    if edge_color is None:
        edge_color = 'k'  # networkx's own default for draw_networkx_edges
    elif torch.is_tensor(edge_color):
        edge_color = edge_color.detach().cpu().numpy()

    fig, ax = plt.subplots(dpi=300)
    try:
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=.001, hspace=.001)
        ax.axis('off')

        if pos is None:
            pos = nx.kamada_kawai_layout(nx_Gs)
        else:
            # Copy before flipping: Tensor.numpy() shares memory with the tensor, so an
            # in-place flip would otherwise corrupt the caller's positions.
            coords = torch.as_tensor(pos).detach().cpu().clone().numpy()
            coords[:, 1] = -coords[:, 1]
            pos = {i: coords[i] for i in range(coords.shape[0])}

        # "<|-" puts the arrowhead at the start (source end) of each drawn edge.
        nx.draw_networkx_edges(nx_Gs, pos, edgelist=nx_Gs.edges, node_size=node_size,
                               arrowstyle="<|-", arrows=arrows, ax=ax, width=line_width,
                               edge_color=edge_color, edge_cmap=plt.get_cmap('hot'),
                               edge_vmin=0, edge_vmax=1)
        node_collection = nx.draw_networkx_nodes(nx_Gs, pos, nodelist=nx_Gs.nodes, node_color=colors,
                                                 cmap=plt.get_cmap('tab20c'), vmin=vmin, vmax=vmax,
                                                 ax=ax, node_size=node_size, edgecolors='black')
        if enable_label:
            nx.draw_networkx_labels(nx_Gs, pos, labels=nx.get_node_attributes(nx_Gs, label_attr),
                                    font_color=font_color, font_size=font_size, ax=ax)
        if enable_colorbar:
            fig.colorbar(node_collection, ax=ax)

        if save_fig_path:
            dir_name = os.path.dirname(save_fig_path)
            # dirname is '' for a bare filename, and os.makedirs('') raises.
            if dir_name:
                os.makedirs(dir_name, exist_ok=True)
            fig.savefig(save_fig_path, transparent=True)
        else:
            fig.show()
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
def plot_calculation_graph(calculation_graph, graph_idx=0, color_attr=None, k_hop=None, *,
                           max_nodes=1000, **kwargs):
    r"""
    Select a (sub)graph from a PyG data object and plot it with :func:`plot_graph`.

    Node selection:
      * if ``calculation_graph.batch`` is set: the nodes whose batch id equals ``graph_idx``;
      * otherwise, if ``k_hop`` is truthy: the ``k_hop``-hop neighbourhood of node(s)
        ``graph_idx``. If it has more than ``max_nodes`` nodes, a message is printed and the
        1-hop neighbourhood is used instead. If that is still too large, nothing is plotted;
      * otherwise: every node.

    The selected nodes are relabeled ``0 .. len(subset) - 1`` (node ``i`` is ``subset[i]``),
    self-loops are removed, and the result is converted to networkx.

    Args:
        calculation_graph (Data or Batch): graph data to be processed.
        graph_idx (int, list, tuple or torch.Tensor): graph index within the batch when
            ``calculation_graph.batch`` is set; otherwise the central node(s) for the
            ``k_hop`` neighbourhood.
        color_attr (str or None): node attribute of ``calculation_graph`` used for node colors.
        k_hop (int or None): number of hops to plot around ``graph_idx``; None/0 plots the
            whole graph. Ignored when ``batch`` is set.
        max_nodes (int): largest neighbourhood (in nodes) that will be plotted.
        **kwargs: keyword arguments for :func:`~plot_graph`. ``label_attr`` is also read
            here, to copy that attribute onto the plotted nodes.

    Returns:
        None
    """
    batch = getattr(calculation_graph, 'batch', None)
    edge_index = calculation_graph.edge_index
    num_nodes = calculation_graph.num_nodes

    if batch is not None:
        subset = torch.where(batch == graph_idx)[0]
    elif k_hop:
        subset = k_hop_subgraph(node_idx=graph_idx, num_hops=k_hop, edge_index=edge_index,
                                num_nodes=num_nodes)[0]
        if subset.shape[0] > max_nodes:
            print(f'The graph is too large. {subset.shape[0]}')
            subset = k_hop_subgraph(node_idx=graph_idx, num_hops=1, edge_index=edge_index,
                                    num_nodes=num_nodes)[0]
            if subset.shape[0] > max_nodes:
                return
    else:
        subset = torch.arange(num_nodes, dtype=torch.long, device=edge_index.device)

    sub_edge_index = subgraph(subset=subset, edge_index=edge_index, relabel_nodes=True,
                              num_nodes=num_nodes)[0]
    vis_edge_index = remove_self_loops(sub_edge_index)[0]
    # Pass num_nodes explicitly. Otherwise PyG may infer it from edge_index.max() + 1, and
    # selected nodes without (non-self-loop) edges would vanish from the plot.
    data = Data(edge_index=vis_edge_index, num_nodes=subset.shape[0])

    node_attrs = []
    if color_attr:
        setattr(data, color_attr, calculation_graph.get(color_attr)[subset])
        node_attrs.append(color_attr)
    label_attr = kwargs.get('label_attr')
    if label_attr:
        labels = calculation_graph.get(label_attr)
        if isinstance(labels, (list, tuple)):
            label_subset = [labels[i] for i in subset.tolist()]
        else:
            label_subset = labels[subset]
        setattr(data, label_attr, label_subset)
        node_attrs.append(label_attr)

    vis_nx_G = to_networkx(data, node_attrs=node_attrs)
    plot_graph(vis_nx_G, color_attr=color_attr, **kwargs)
    print('One graph plotted')