/
gcn.py
354 lines (314 loc) · 13.2 KB
/
gcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
"""
DGL-based GCN for graph property prediction.
"""
import torch.nn as nn
import torch.nn.functional as F
from deepchem.models.losses import Loss, L2Loss, SparseSoftmaxCrossEntropy
from deepchem.models.torch_models.torch_model import TorchModel
class GCN(nn.Module):
"""Model for Graph Property Prediction Based on Graph Convolution Networks (GCN).
This model proceeds as follows:
* Update node representations in graphs with a variant of GCN
* For each graph, compute its representation by 1) a weighted sum of the node
representations in the graph, where the weights are computed by applying a
gating function to the node representations 2) a max pooling of the node
representations 3) concatenating the output of 1) and 2)
* Perform the final prediction using an MLP
Examples
--------
>>> import deepchem as dc
>>> import dgl
>>> from deepchem.models import GCN
>>> smiles = ["C1CCC1", "C1=CC=CN=C1"]
>>> featurizer = dc.feat.MolGraphConvFeaturizer()
>>> graphs = featurizer.featurize(smiles)
>>> print(type(graphs[0]))
<class 'deepchem.feat.graph_data.GraphData'>
>>> dgl_graphs = [graphs[i].to_dgl_graph(self_loop=True) for i in range(len(graphs))]
>>> # Batch two graphs into a graph of two connected components
>>> batch_dgl_graph = dgl.batch(dgl_graphs)
>>> model = GCN(n_tasks=1, mode='regression')
>>> preds = model(batch_dgl_graph)
>>> print(type(preds))
<class 'torch.Tensor'>
>>> preds.shape == (2, 1)
True
References
----------
.. [1] Thomas N. Kipf and Max Welling. "Semi-Supervised Classification with Graph
Convolutional Networks." ICLR 2017.
Notes
-----
This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci
(https://github.com/awslabs/dgl-lifesci) to be installed.
This model is different from deepchem.models.GraphConvModel as follows:
* For each graph convolution, the learnable weight in this model is shared across all nodes.
``GraphConvModel`` employs separate learnable weights for nodes of different degrees. A
learnable weight is shared across all nodes of a particular degree.
* For ``GraphConvModel``, there is an additional GraphPool operation after each
graph convolution. The operation updates the representation of a node by applying an
element-wise maximum over the representations of its neighbors and itself.
* For computing graph-level representations, this model computes a weighted sum and an
element-wise maximum of the representations of all nodes in a graph and concatenates them.
The node weights are obtained by using a linear/dense layer followd by a sigmoid function.
For ``GraphConvModel``, the sum over node representations is unweighted.
* There are various minor differences in using dropout, skip connection and batch
normalization.
"""
def __init__(self,
n_tasks: int,
graph_conv_layers: list = None,
activation=None,
residual: bool = True,
batchnorm: bool = False,
dropout: float = 0.,
predictor_hidden_feats: int = 128,
predictor_dropout: float = 0.,
mode: str = 'regression',
number_atom_features: int = 30,
n_classes: int = 2,
nfeat_name: str = 'x'):
"""
Parameters
----------
n_tasks: int
Number of tasks.
graph_conv_layers: list of int
Width of channels for GCN layers. graph_conv_layers[i] gives the width of channel
for the i-th GCN layer. If not specified, the default value will be [64, 64].
activation: callable
The activation function to apply to the output of each GCN layer.
By default, no activation function will be applied.
residual: bool
Whether to add a residual connection within each GCN layer. Default to True.
batchnorm: bool
Whether to apply batch normalization to the output of each GCN layer.
Default to False.
dropout: float
The dropout probability for the output of each GCN layer. Default to 0.
predictor_hidden_feats: int
The size for hidden representations in the output MLP predictor. Default to 128.
predictor_dropout: float
The dropout probability in the output MLP predictor. Default to 0.
mode: str
The model type, 'classification' or 'regression'. Default to 'regression'.
number_atom_features: int
The length of the initial atom feature vectors. Default to 30.
n_classes: int
The number of classes to predict per task
(only used when ``mode`` is 'classification'). Default to 2.
nfeat_name: str
For an input graph ``g``, the model assumes that it stores node features in
``g.ndata[nfeat_name]`` and will retrieve input node features from that.
Default to 'x'.
"""
try:
import dgl
except:
raise ImportError('This class requires dgl.')
try:
import dgllife
except:
raise ImportError('This class requires dgllife.')
if mode not in ['classification', 'regression']:
raise ValueError("mode must be either 'classification' or 'regression'")
super(GCN, self).__init__()
self.n_tasks = n_tasks
self.mode = mode
self.n_classes = n_classes
self.nfeat_name = nfeat_name
if mode == 'classification':
out_size = n_tasks * n_classes
else:
out_size = n_tasks
from dgllife.model import GCNPredictor as DGLGCNPredictor
if graph_conv_layers is None:
graph_conv_layers = [64, 64]
num_gnn_layers = len(graph_conv_layers)
if activation is not None:
activation = [activation] * num_gnn_layers
self.model = DGLGCNPredictor(
in_feats=number_atom_features,
hidden_feats=graph_conv_layers,
activation=activation,
residual=[residual] * num_gnn_layers,
batchnorm=[batchnorm] * num_gnn_layers,
dropout=[dropout] * num_gnn_layers,
n_tasks=out_size,
predictor_hidden_feats=predictor_hidden_feats,
predictor_dropout=predictor_dropout)
def forward(self, g):
"""Predict graph labels
Parameters
----------
g: DGLGraph
A DGLGraph for a batch of graphs. It stores the node features in
``dgl_graph.ndata[self.nfeat_name]``.
Returns
-------
torch.Tensor
The model output.
* When self.mode = 'regression',
its shape will be ``(dgl_graph.batch_size, self.n_tasks)``.
* When self.mode = 'classification', the output consists of probabilities
for classes. Its shape will be ``(dgl_graph.batch_size, self.n_tasks, self.n_classes)``
if self.n_tasks > 1; its shape will be ``(dgl_graph.batch_size, self.n_classes)`` if
self.n_tasks is 1.
torch.Tensor, optional
This is only returned when self.mode = 'classification', the output consists of the
logits for classes before softmax.
"""
node_feats = g.ndata[self.nfeat_name]
out = self.model(g, node_feats)
if self.mode == 'classification':
if self.n_tasks == 1:
logits = out.view(-1, self.n_classes)
softmax_dim = 1
else:
logits = out.view(-1, self.n_tasks, self.n_classes)
softmax_dim = 2
proba = F.softmax(logits, dim=softmax_dim)
return proba, logits
else:
return out
class GCNModel(TorchModel):
"""Model for Graph Property Prediction Based on Graph Convolution Networks (GCN).
This model proceeds as follows:
* Update node representations in graphs with a variant of GCN
* For each graph, compute its representation by 1) a weighted sum of the node
representations in the graph, where the weights are computed by applying a
gating function to the node representations 2) a max pooling of the node
representations 3) concatenating the output of 1) and 2)
* Perform the final prediction using an MLP
Examples
--------
>>>
>> import deepchem as dc
>> from deepchem.models import GCNModel
>> featurizer = dc.feat.MolGraphConvFeaturizer()
>> tasks, datasets, transformers = dc.molnet.load_tox21(
.. reload=False, featurizer=featurizer, transformers=[])
>> train, valid, test = datasets
>> model = GCNModel(mode='classification', n_tasks=len(tasks),
.. batch_size=32, learning_rate=0.001)
>> model.fit(train, nb_epoch=50)
References
----------
.. [1] Thomas N. Kipf and Max Welling. "Semi-Supervised Classification with Graph
Convolutional Networks." ICLR 2017.
Notes
-----
This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci
(https://github.com/awslabs/dgl-lifesci) to be installed.
This model is different from deepchem.models.GraphConvModel as follows:
* For each graph convolution, the learnable weight in this model is shared across all nodes.
``GraphConvModel`` employs separate learnable weights for nodes of different degrees. A
learnable weight is shared across all nodes of a particular degree.
* For ``GraphConvModel``, there is an additional GraphPool operation after each
graph convolution. The operation updates the representation of a node by applying an
element-wise maximum over the representations of its neighbors and itself.
* For computing graph-level representations, this model computes a weighted sum and an
element-wise maximum of the representations of all nodes in a graph and concatenates them.
The node weights are obtained by using a linear/dense layer followd by a sigmoid function.
For ``GraphConvModel``, the sum over node representations is unweighted.
* There are various minor differences in using dropout, skip connection and batch
normalization.
"""
def __init__(self,
n_tasks: int,
graph_conv_layers: list = None,
activation=None,
residual: bool = True,
batchnorm: bool = False,
dropout: float = 0.,
predictor_hidden_feats: int = 128,
predictor_dropout: float = 0.,
mode: str = 'regression',
number_atom_features=30,
n_classes: int = 2,
self_loop: bool = True,
**kwargs):
"""
Parameters
----------
n_tasks: int
Number of tasks.
graph_conv_layers: list of int
Width of channels for GCN layers. graph_conv_layers[i] gives the width of channel
for the i-th GCN layer. If not specified, the default value will be [64, 64].
activation: callable
The activation function to apply to the output of each GCN layer.
By default, no activation function will be applied.
residual: bool
Whether to add a residual connection within each GCN layer. Default to True.
batchnorm: bool
Whether to apply batch normalization to the output of each GCN layer.
Default to False.
dropout: float
The dropout probability for the output of each GCN layer. Default to 0.
predictor_hidden_feats: int
The size for hidden representations in the output MLP predictor. Default to 128.
predictor_dropout: float
The dropout probability in the output MLP predictor. Default to 0.
mode: str
The model type, 'classification' or 'regression'. Default to 'regression'.
number_atom_features: int
The length of the initial atom feature vectors. Default to 30.
n_classes: int
The number of classes to predict per task
(only used when ``mode`` is 'classification'). Default to 2.
self_loop: bool
Whether to add self loops for the nodes, i.e. edges from nodes to themselves.
When input graphs have isolated nodes, self loops allow preserving the original feature
of them in message passing. Default to True.
kwargs
This can include any keyword argument of TorchModel.
"""
model = GCN(
n_tasks=n_tasks,
graph_conv_layers=graph_conv_layers,
activation=activation,
residual=residual,
batchnorm=batchnorm,
dropout=dropout,
predictor_hidden_feats=predictor_hidden_feats,
predictor_dropout=predictor_dropout,
mode=mode,
number_atom_features=number_atom_features,
n_classes=n_classes)
if mode == 'regression':
loss: Loss = L2Loss()
output_types = ['prediction']
else:
loss = SparseSoftmaxCrossEntropy()
output_types = ['prediction', 'loss']
super(GCNModel, self).__init__(
model, loss=loss, output_types=output_types, **kwargs)
self._self_loop = self_loop
def _prepare_batch(self, batch):
"""Create batch data for GCN.
Parameters
----------
batch: tuple
The tuple is ``(inputs, labels, weights)``.
Returns
-------
inputs: DGLGraph
DGLGraph for a batch of graphs.
labels: list of torch.Tensor or None
The graph labels.
weights: list of torch.Tensor or None
The weights for each sample or sample/task pair converted to torch.Tensor.
"""
try:
import dgl
except:
raise ImportError('This class requires dgl.')
inputs, labels, weights = batch
dgl_graphs = [
graph.to_dgl_graph(self_loop=self._self_loop) for graph in inputs[0]
]
inputs = dgl.batch(dgl_graphs).to(self.device)
_, labels, weights = super(GCNModel, self)._prepare_batch(([], labels,
weights))
return inputs, labels, weights