From f766c92a63fb67a99744dcb82c8c78406fbf905a Mon Sep 17 00:00:00 2001 From: mufeili Date: Wed, 4 Nov 2020 18:07:30 +0800 Subject: [PATCH 1/8] Update --- deepchem/models/tests/test_gat.py | 107 +++--- deepchem/models/torch_models/gat.py | 533 +++++++++++++++++----------- deepchem/models/torch_models/gcn.py | 2 +- 3 files changed, 381 insertions(+), 261 deletions(-) diff --git a/deepchem/models/tests/test_gat.py b/deepchem/models/tests/test_gat.py index b37fa82648..b889be6532 100644 --- a/deepchem/models/tests/test_gat.py +++ b/deepchem/models/tests/test_gat.py @@ -9,15 +9,16 @@ from deepchem.models.tests.test_graph_models import get_dataset try: - import torch # noqa - import torch_geometric # noqa - has_pytorch_and_pyg = True + import dgl + import dgllife + import torch + has_torch_and_dgl = True except: - has_pytorch_and_pyg = False + has_torch_and_dgl = False -@unittest.skipIf(not has_pytorch_and_pyg, - 'PyTorch and PyTorch Geometric are not installed') +@unittest.skipIf(not has_torch_and_dgl, + 'PyTorch, DGL, or DGL-LifeSci are not installed') def test_gat_regression(): # load datasets featurizer = MolGraphConvFeaturizer() @@ -26,17 +27,20 @@ def test_gat_regression(): # initialize models n_tasks = len(tasks) - model = GATModel(mode='regression', n_tasks=n_tasks, batch_size=10) + model = GATModel( + mode='regression', + n_tasks=n_tasks, + number_atom_features=30, + batch_size=10) # overfit test - # GAT's convergence is a little slow - model.fit(dataset, nb_epoch=300) + model.fit(dataset, nb_epoch=100) scores = model.evaluate(dataset, [metric], transformers) - assert scores['mean_absolute_error'] < 0.75 + assert scores['mean_absolute_error'] < 0.5 -@unittest.skipIf(not has_pytorch_and_pyg, - 'PyTorch and PyTorch Geometric are not installed') +@unittest.skipIf(not has_torch_and_dgl, + 'PyTorch, DGL, or DGL-LifeSci are not installed') def test_gat_classification(): # load datasets featurizer = MolGraphConvFeaturizer() @@ -48,49 +52,50 @@ def test_gat_classification(): model = GATModel( mode='classification', n_tasks=n_tasks, + number_atom_features=30, batch_size=10, learning_rate=0.001) # overfit test - # GAT's convergence is a little slow - model.fit(dataset, nb_epoch=150) + model.fit(dataset, nb_epoch=50) scores = model.evaluate(dataset, [metric], transformers) - assert scores['mean-roc_auc_score'] >= 0.70 - + assert scores['mean-roc_auc_score'] >= 0.85 -@unittest.skipIf(not has_pytorch_and_pyg, - 'PyTorch and PyTorch Geometric are not installed') +@unittest.skipIf(not has_torch_and_dgl, + 'PyTorch, DGL, or DGL-LifeSci are not installed') def test_gat_reload(): - # load datasets - featurizer = MolGraphConvFeaturizer() - tasks, dataset, transformers, metric = get_dataset( - 'classification', featurizer=featurizer) - - # initialize models - n_tasks = len(tasks) - model_dir = tempfile.mkdtemp() - model = GATModel( - mode='classification', - n_tasks=n_tasks, - model_dir=model_dir, - batch_size=10, - learning_rate=0.001) - - model.fit(dataset, nb_epoch=150) - scores = model.evaluate(dataset, [metric], transformers) - assert scores['mean-roc_auc_score'] >= 0.70 - - reloaded_model = GATModel( - mode='classification', - n_tasks=n_tasks, - model_dir=model_dir, - batch_size=10, - learning_rate=0.001) - reloaded_model.restore() - - pred_mols = ["CCCC", "CCCCCO", "CCCCC"] - X_pred = featurizer(pred_mols) - random_dataset = dc.data.NumpyDataset(X_pred) - original_pred = model.predict(random_dataset) - reload_pred = reloaded_model.predict(random_dataset) - assert np.all(original_pred == reload_pred) 
+ # load datasets + featurizer = MolGraphConvFeaturizer() + tasks, dataset, transformers, metric = get_dataset( + 'classification', featurizer=featurizer) + + # initialize models + n_tasks = len(tasks) + model_dir = tempfile.mkdtemp() + model = GATModel( + mode='classification', + n_tasks=n_tasks, + number_atom_features=30, + model_dir=model_dir, + batch_size=10, + learning_rate=0.001) + + model.fit(dataset, nb_epoch=50) + scores = model.evaluate(dataset, [metric], transformers) + assert scores['mean-roc_auc_score'] >= 0.85 + + reloaded_model = GATModel( + mode='classification', + n_tasks=n_tasks, + number_atom_features=30, + model_dir=model_dir, + batch_size=10, + learning_rate=0.001) + reloaded_model.restore() + + pred_mols = ["CCCC", "CCCCCO", "CCCCC"] + X_pred = featurizer(pred_mols) + random_dataset = dc.data.NumpyDataset(X_pred) + original_pred = model.predict(random_dataset) + reload_pred = reloaded_model.predict(random_dataset) + assert np.all(original_pred == reload_pred) diff --git a/deepchem/models/torch_models/gat.py b/deepchem/models/torch_models/gat.py index fca8dd2522..833bb3315d 100644 --- a/deepchem/models/torch_models/gat.py +++ b/deepchem/models/torch_models/gat.py @@ -1,224 +1,334 @@ """ -This is a sample implementation for working PyTorch Geometric with DeepChem! +DGL-based GAT for graph property prediction. """ -import torch import torch.nn as nn import torch.nn.functional as F -from deepchem.models.torch_models.torch_model import TorchModel from deepchem.models.losses import Loss, L2Loss, SparseSoftmaxCrossEntropy - +from deepchem.models.torch_models.torch_model import TorchModel class GAT(nn.Module): - """Graph Attention Networks. - - This model takes arbitary graphs as an input, and predict graph properties. This model is - one of variants of Graph Convolutional Networks. The main difference between basic GCN models - is how to update node representations. The GAT uses multi head attention mechanisms which - outbroke in NLP like Transformer when updating node representations. The most important advantage - of this approach is that we can get the interpretability like how the model predict the value - or which part of the graph structure is important from attention-weight. Please confirm - the detail algorithms from [1]_. - - Examples - -------- - >>> import deepchem as dc - >>> from torch_geometric.data import Batch - >>> smiles = ["C1CCC1", "C1=CC=CN=C1"] - >>> featurizer = dc.feat.MolGraphConvFeaturizer() - >>> graphs = featurizer.featurize(smiles) - >>> print(type(graphs[0])) - - >>> pyg_graphs = [graph.to_pyg_graph() for graph in graphs] - >>> print(type(pyg_graphs[0])) - - >>> model = dc.models.GAT(mode='classification', n_tasks=10, n_classes=2) - >>> preds, logits = model(Batch.from_data_list(pyg_graphs)) - >>> print(type(preds)) - - >>> preds.shape == (2, 10, 2) - True - - References - ---------- - .. [1] Veličković, Petar, et al. "Graph attention networks." arXiv preprint - arXiv:1710.10903 (2017). - - Notes - ----- - This class requires PyTorch Geometric to be installed. - """ - - def __init__( - self, - in_node_dim: int = 30, - hidden_node_dim: int = 32, - heads: int = 1, - dropout: float = 0.0, - num_conv: int = 2, - predictor_hidden_feats: int = 64, - n_tasks: int = 1, - mode: str = 'classification', - n_classes: int = 2, - ): - """ - Parameters + """Model for Graph Property Prediction Based on Graph Attention Networks (GAT). 
+ + This model proceeds as follows: + + * Update node representations in graphs with a variant of GAT + * For each graph, compute its representation by 1) a weighted sum of the node + representations in the graph, where the weights are computed by applying a + gating function to the node representations 2) a max pooling of the node + representations 3) concatenating the output of 1) and 2) + * Perform the final prediction using an MLP + + Examples + -------- + + >>> import deepchem as dc + >>> import dgl + >>> from deepchem.models import GAT + >>> smiles = ["C1CCC1", "C1=CC=CN=C1"] + >>> featurizer = dc.feat.MolGraphConvFeaturizer() + >>> graphs = featurizer.featurize(smiles) + >>> print(type(graphs[0])) + + >>> dgl_graphs = [graphs[i].to_dgl_graph() for i in range(len(graphs))] + >>> # Batch two graphs into a graph of two connected components + >>> batch_dgl_graph = dgl.batch(dgl_graphs) + >>> model = GAT(n_tasks=1, number_atom_features=30, mode='regression') + >>> preds = model(batch_dgl_graph) + >>> print(type(preds)) + + >>> preds.shape == (2, 1) + True + + References ---------- - in_node_dim: int, default 30 - The length of the initial node feature vectors. The 30 is - based on `MolGraphConvFeaturizer`. - hidden_node_dim: int, default 32 - The length of the hidden node feature vectors. - heads: int, default 1 - The number of multi-head-attentions. - dropout: float, default 0.0 - The dropout probability for each convolutional layer. - num_conv: int, default 2 - The number of convolutional layers. - predictor_hidden_feats: int, default 64 - The size for hidden representations in the output MLP predictor, default to 64. - n_tasks: int, default 1 - The number of the output size, default to 1. - mode: str, default 'classification' - The model type, 'classification' or 'regression'. - n_classes: int, default 2 - The number of classes to predict (only used in classification mode). + .. [1] Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, + and Yoshua Bengio. "Graph Attention Networks." ICLR 2018. + + Notes + ----- + This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci + (https://github.com/awslabs/dgl-lifesci) to be installed. """ - super(GAT, self).__init__() + def __init__(self, + n_tasks: int, + graph_attention_layers: list = None, + n_attention_heads: int = 8, + agg_modes: list = None, + activation=F.elu, + residual: bool = True, + dropout: float = 0., + alpha: float = 0.2, + predictor_hidden_feats: int = 128, + predictor_dropout: float = 0., + mode: str = 'regression', + number_atom_features: int = 75, + n_classes: int = 2, + nfeat_name: str = 'x'): + """ + Parameters + ---------- + n_tasks: int + Number of tasks. + graph_attention_layers: list of int + Width of channels per attention head for GAT layers. graph_attention_layers[i] + gives the width of channel for each attention head for the i-th GAT layer. If + both ``graph_attention_layers`` and ``agg_modes`` are specified, they should have + equal length. If not specified, the default value will be [8, 8]. + n_attention_heads: int + Number of attention heads in each GAT layer. + agg_modes: list of str + The way to aggregate multi-head attention results for each GAT layer, which can be + either 'flatten' for concatenating all-head results or 'mean' for averaging all-head + results. ``agg_modes[i]`` gives the way to aggregate multi-head attention results for + the i-th GAT layer. If both ``graph_attention_layers`` and ``agg_modes`` are + specified, they should have equal length. 
If not specified, the model will flatten + multi-head results for intermediate GAT layers and compute mean of multi-head results + for the last GAT layer. + activation: activation function or None + The activation function to apply to the aggregated multi-head results for each GAT + layer. If not specified, the default value will be ELU. + residual: bool + Whether to add a residual connection within each GAT layer. Default to True. + dropout: float + The dropout probability within each GAT layer. Default to 0. + alpha: float + A hyperparameter in LeakyReLU, which is the slope for negative values. Default to 0.2. + predictor_hidden_feats: int + The size for hidden representations in the output MLP predictor. Default to 128. + predictor_dropout: float + The dropout probability in the output MLP predictor. Default to 0. + mode: str + The model type, 'classification' or 'regression'. + number_atom_features: int + The length of the initial atom feature vectors. Default to 75. + n_classes: int + The number of classes to predict per task + (only used when ``mode`` is 'classification'). + nfeat_name: str + For an input graph ``g``, the model assumes that it stores node features in + ``g.ndata[nfeat_name]`` and will retrieve input node features from that. + """ try: - from torch_geometric.nn import GATConv, global_mean_pool + import dgl except: - raise ImportError( - "This class requires PyTorch Geometric to be installed.") + raise ImportError('This class requires dgl.') + try: + import dgllife + except: + raise ImportError('This class requires dgllife.') + + if mode not in ['classification', 'regression']: + raise ValueError("mode must be either 'classification' or 'regression'") + + super(GAT, self).__init__() self.n_tasks = n_tasks self.mode = mode self.n_classes = n_classes - self.embedding = nn.Linear(in_node_dim, hidden_node_dim) - self.conv_layers = nn.ModuleList([ - GATConv( - in_channels=hidden_node_dim, - out_channels=hidden_node_dim, - heads=heads, - concat=False, - dropout=dropout) for _ in range(num_conv) - ]) - self.pooling = global_mean_pool - self.fc = nn.Linear(hidden_node_dim, predictor_hidden_feats) - if self.mode == 'regression': - self.out = nn.Linear(predictor_hidden_feats, n_tasks) + self.nfeat_name = nfeat_name + if mode == 'classification': + out_size = n_tasks * n_classes else: - self.out = nn.Linear(predictor_hidden_feats, n_tasks * n_classes) + out_size = n_tasks - def forward(self, data): - """Predict labels + from dgllife.model import GATPredictor as DGLGATPredictor - Parameters - ---------- - data: torch_geometric.data.Batch - A mini-batch graph data for PyTorch Geometric models. - - Returns - ------- - out: torch.Tensor - If mode == 'regression', the shape is `(batch_size, n_tasks)`. - If mode == 'classification', the shape is `(batch_size, n_tasks, n_classes)` (n_tasks > 1) - or `(batch_size, n_classes)` (n_tasks == 1) and the output values are probabilities of each class label. 
- """ - node_feat, edge_index = data.x, data.edge_index - node_feat = self.embedding(node_feat) + if isinstance(graph_attention_layers, list) and isinstance(agg_modes, list): + assert len(graph_attention_layers) == len(agg_modes), \ + 'Expect graph_attention_layers and agg_modes to have equal length, ' \ + 'got {:d} and {:d}'.format(len(graph_attention_layers), len(agg_modes)) - # convolutional layer - for conv in self.conv_layers: - node_feat = conv(node_feat, edge_index) + # Decide first number of GAT layers + if graph_attention_layers is not None: + num_gnn_layers = len(graph_attention_layers) + elif agg_modes is not None: + num_gnn_layers = len(agg_modes) + else: + num_gnn_layers = 2 - # pooling - graph_feat = self.pooling(node_feat, data.batch) - graph_feat = F.leaky_relu(self.fc(graph_feat)) - out = self.out(graph_feat) + if graph_attention_layers is None: + graph_attention_layers = [8] * num_gnn_layers + if agg_modes is None: + agg_modes = ['flatten' for _ in range(num_gnn_layers - 1)] + agg_modes.append('mean') - if self.mode == 'regression': - return out - else: - logits = out.view(-1, self.n_tasks, self.n_classes) - # for n_tasks == 1 case - logits = torch.squeeze(logits) - proba = F.softmax(logits, dim=-1) + if activation is not None: + activation = [activation] * num_gnn_layers + + self.model = DGLGATPredictor( + in_feats=number_atom_features, + hidden_feats=graph_attention_layers, + num_heads=[n_attention_heads] * num_gnn_layers, + feat_drops=[dropout] * num_gnn_layers, + attn_drops=[dropout] * num_gnn_layers, + alphas=[alpha] * num_gnn_layers, + residuals=[residual] * num_gnn_layers, + agg_modes=agg_modes, + activations=activation, + n_tasks=out_size, + predictor_hidden_feats=predictor_hidden_feats, + predictor_dropout=predictor_dropout + ) + + def forward(self, g): + """Predict graph labels + + Parameters + ---------- + g: DGLGraph + A DGLGraph for a batch of graphs. It stores the node features in + ``dgl_graph.ndata[self.nfeat_name]``. + + Returns + ------- + torch.Tensor + The model output. + + * When self.mode = 'regression', + its shape will be ``(dgl_graph.batch_size, self.n_tasks)``. + * When self.mode = 'classification', the output consists of probabilities + for classes. Its shape will be + ``(dgl_graph.batch_size, self.n_tasks, self.n_classes)`` if self.n_tasks > 1; + its shape will be ``(dgl_graph.batch_size, self.n_classes)`` if self.n_tasks is 1. + torch.Tensor, optional + This is only returned when self.mode = 'classification', the output consists of the + logits for classes before softmax. + """ + node_feats = g.ndata[self.nfeat_name] + out = self.model(g, node_feats) + + if self.mode == 'classification': + if self.n_tasks == 1: + logits = out.view(-1, self.n_classes) + softmax_dim = 1 + else: + logits = out.view(-1, self.n_tasks, self.n_classes) + softmax_dim = 2 + proba = F.softmax(logits, dim=softmax_dim) return proba, logits + else: + return out class GATModel(TorchModel): - """Graph Attention Networks (GAT). - - Here is a simple example of code that uses the GATModel with - molecules dataset. - - >> import deepchem as dc - >> featurizer = dc.feat.MolGraphConvFeaturizer() - >> tasks, datasets, transformers = dc.molnet.load_tox21(reload=False, featurizer=featurizer, transformers=[]) - >> train, valid, test = datasets - >> model = dc.models.GATModel(mode='classification', n_tasks=len(tasks), batch_size=32, learning_rate=0.001) - >> model.fit(train, nb_epoch=50) - - This model takes arbitary graphs as an input, and predict graph properties. 
This model is - one of variants of Graph Convolutional Networks. The main difference between basic GCN models - is how to update node representations. The GAT uses multi head attention mechanisms which - outbroke in NLP like Transformer when updating node representations. The most important advantage - of this approach is that we can get the interpretability like how the model predict the value - or which part of the graph structure is important from attention-weight. Please confirm - the detail algorithms from [1]_. - - References - ---------- - .. [1] Veličković, Petar, et al. "Graph attention networks." arXiv preprint - arXiv:1710.10903 (2017). - - Notes - ----- - This class requires PyTorch Geometric to be installed. - """ + """Model for Graph Property Prediction Based on Graph Attention Networks (GAT). + + This model proceeds as follows: + + * Update node representations in graphs with a variant of GAT + * For each graph, compute its representation by 1) a weighted sum of the node + representations in the graph, where the weights are computed by applying a + gating function to the node representations 2) a max pooling of the node + representations 3) concatenating the output of 1) and 2) + * Perform the final prediction using an MLP + + Examples + -------- + >>> + >> import deepchem as dc + >> from deepchem.models import GATModel + >> featurizer = dc.feat.MolGraphConvFeaturizer() + >> tasks, datasets, transformers = dc.molnet.load_tox21( + .. reload=False, featurizer=featurizer, transformers=[]) + >> train, valid, test = datasets + >> model = dc.models.GATModel(mode='classification', n_tasks=len(tasks), + .. number_atom_features=30, batch_size=32, learning_rate=0.001) + >> model.fit(train, nb_epoch=50) + + References + ---------- + .. [1] Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, + and Yoshua Bengio. "Graph Attention Networks." ICLR 2018. + + Notes + ----- + This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci + (https://github.com/awslabs/dgl-lifesci) to be installed. + """ def __init__(self, - in_node_dim: int = 30, - hidden_node_dim: int = 32, - heads: int = 1, - dropout: float = 0.0, - num_conv: int = 2, - predictor_hidden_feats: int = 64, - n_tasks: int = 1, + n_tasks: int, + graph_attention_layers: list = None, + n_attention_heads: int = 8, + agg_modes: list = None, + activation=F.elu, + residual: bool = True, + dropout: float = 0., + alpha: float = 0.2, + predictor_hidden_feats: int = 128, + predictor_dropout: float = 0., mode: str = 'regression', + number_atom_features: int = 75, n_classes: int = 2, + nfeat_name: str = 'x', + self_loop: bool = True, **kwargs): """ - This class accepts all the keyword arguments from TorchModel. - - Parameters - ---------- - in_node_dim: int, default 30 - The length of the initial node feature vectors. The 30 is - based on `MolGraphConvFeaturizer`. - hidden_node_dim: int, default 32 - The length of the hidden node feature vectors. - heads: int, default 1 - The number of multi-head-attentions. - dropout: float, default 0.0 - The dropout probability for each convolutional layer. - num_conv: int, default 2 - The number of convolutional layers. - predictor_hidden_feats: int, default 64 - The size for hidden representations in the output MLP predictor, default to 64. - n_tasks: int, default 1 - The number of the output size, default to 1. - mode: str, default 'regression' - The model type, 'classification' or 'regression'. 
- n_classes: int, default 2 - The number of classes to predict (only used in classification mode). - kwargs: Dict - This class accepts all the keyword arguments from TorchModel. - """ - model = GAT(in_node_dim, hidden_node_dim, heads, dropout, num_conv, - predictor_hidden_feats, n_tasks, mode, n_classes) - if mode == "regression": + Parameters + ---------- + n_tasks: int + Number of tasks. + graph_attention_layers: list of int + Width of channels per attention head for GAT layers. graph_attention_layers[i] + gives the width of channel for each attention head for the i-th GAT layer. If + both ``graph_attention_layers`` and ``agg_modes`` are specified, they should have + equal length. If not specified, the default value will be [8, 8]. + n_attention_heads: int + Number of attention heads in each GAT layer. + agg_modes: list of str + The way to aggregate multi-head attention results for each GAT layer, which can be + either 'flatten' for concatenating all-head results or 'mean' for averaging all-head + results. ``agg_modes[i]`` gives the way to aggregate multi-head attention results for + the i-th GAT layer. If both ``graph_attention_layers`` and ``agg_modes`` are + specified, they should have equal length. If not specified, the model will flatten + multi-head results for intermediate GAT layers and compute mean of multi-head results + for the last GAT layer. + activation: activation function or None + The activation function to apply to the aggregated multi-head results for each GAT + layer. If not specified, the default value will be ELU. + residual: bool + Whether to add a residual connection within each GAT layer. Default to True. + dropout: float + The dropout probability within each GAT layer. Default to 0. + alpha: float + A hyperparameter in LeakyReLU, which is the slope for negative values. Default to 0.2. + predictor_hidden_feats: int + The size for hidden representations in the output MLP predictor. Default to 128. + predictor_dropout: float + The dropout probability in the output MLP predictor. Default to 0. + mode: str + The model type, 'classification' or 'regression'. + number_atom_features: int + The length of the initial atom feature vectors. Default to 75. + n_classes: int + The number of classes to predict per task + (only used when ``mode`` is 'classification'). + nfeat_name: str + For an input graph ``g``, the model assumes that it stores node features in + ``g.ndata[nfeat_name]`` and will retrieve input node features from that. + self_loop: bool + Whether to add self loops for the nodes, i.e. edges from nodes to themselves. + Default to True. + kwargs + This can include any keyword argument of TorchModel. + """ + model = GAT( + n_tasks=n_tasks, + graph_attention_layers=graph_attention_layers, + n_attention_heads=n_attention_heads, + agg_modes=agg_modes, + activation=activation, + residual=residual, + dropout=dropout, + alpha=alpha, + predictor_hidden_feats=predictor_hidden_feats, + predictor_dropout=predictor_dropout, + mode=mode, + number_atom_features=number_atom_features, + n_classes=n_classes, + nfeat_name=nfeat_name) + if mode == 'regression': loss: Loss = L2Loss() output_types = ['prediction'] else: @@ -227,33 +337,38 @@ def __init__(self, super(GATModel, self).__init__( model, loss=loss, output_types=output_types, **kwargs) + self._self_loop = self_loop + def _prepare_batch(self, batch): """Create batch data for GAT. - Parameters - ---------- - batch: Tuple - The tuple are `(inputs, labels, weights)`. 
- - Returns - ------- - inputs: torch_geometric.data.Batch - A mini-batch graph data for PyTorch Geometric models. - labels: List[torch.Tensor] or None - The labels converted to torch.Tensor. - weights: List[torch.Tensor] or None - The weights for each sample or sample/task pair converted to torch.Tensor. - """ + Parameters + ---------- + batch: tuple + The tuple is ``(inputs, labels, weights)``. + self_loop: bool + Whether to add self loops for the nodes, i.e. edges from nodes + to themselves. Default to False. + + Returns + ------- + inputs: DGLGraph + DGLGraph for a batch of graphs. + labels: list of torch.Tensor or None + The graph labels. + weights: list of torch.Tensor or None + The weights for each sample or sample/task pair converted to torch.Tensor. + """ try: - from torch_geometric.data import Batch + import dgl except: - raise ImportError( - "This class requires PyTorch Geometric to be installed.") + raise ImportError('This class requires dgl.') inputs, labels, weights = batch - pyg_graphs = [graph.to_pyg_graph() for graph in inputs[0]] - inputs = Batch.from_data_list(pyg_graphs) - inputs = inputs.to(self.device) + dgl_graphs = [ + graph.to_dgl_graph(self_loop=self._self_loop) for graph in inputs[0] + ] + inputs = dgl.batch(dgl_graphs).to(self.device) _, labels, weights = super(GATModel, self)._prepare_batch(([], labels, weights)) return inputs, labels, weights diff --git a/deepchem/models/torch_models/gcn.py b/deepchem/models/torch_models/gcn.py index c76668f4b2..74be264fe5 100644 --- a/deepchem/models/torch_models/gcn.py +++ b/deepchem/models/torch_models/gcn.py @@ -302,6 +302,7 @@ def __init__(self, This can include any keyword argument of TorchModel. """ model = GCN( + n_tasks=n_tasks, graph_conv_layers=graph_conv_layers, activation=activation, residual=residual, @@ -309,7 +310,6 @@ def __init__(self, dropout=dropout, predictor_hidden_feats=predictor_hidden_feats, predictor_dropout=predictor_dropout, - n_tasks=n_tasks, mode=mode, number_atom_features=number_atom_features, n_classes=n_classes, From 6bfc32b021f46c919259077c38a75f9168f6f2aa Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 4 Nov 2020 10:33:40 +0000 Subject: [PATCH 2/8] Update --- deepchem/models/tests/test_gat.py | 71 +++++++++++++++-------------- deepchem/models/torch_models/gat.py | 6 ++- 2 files changed, 40 insertions(+), 37 deletions(-) diff --git a/deepchem/models/tests/test_gat.py b/deepchem/models/tests/test_gat.py index b889be6532..51568bd8cb 100644 --- a/deepchem/models/tests/test_gat.py +++ b/deepchem/models/tests/test_gat.py @@ -61,41 +61,42 @@ def test_gat_classification(): scores = model.evaluate(dataset, [metric], transformers) assert scores['mean-roc_auc_score'] >= 0.85 + @unittest.skipIf(not has_torch_and_dgl, 'PyTorch, DGL, or DGL-LifeSci are not installed') def test_gat_reload(): - # load datasets - featurizer = MolGraphConvFeaturizer() - tasks, dataset, transformers, metric = get_dataset( - 'classification', featurizer=featurizer) - - # initialize models - n_tasks = len(tasks) - model_dir = tempfile.mkdtemp() - model = GATModel( - mode='classification', - n_tasks=n_tasks, - number_atom_features=30, - model_dir=model_dir, - batch_size=10, - learning_rate=0.001) - - model.fit(dataset, nb_epoch=50) - scores = model.evaluate(dataset, [metric], transformers) - assert scores['mean-roc_auc_score'] >= 0.85 - - reloaded_model = GATModel( - mode='classification', - n_tasks=n_tasks, - number_atom_features=30, - model_dir=model_dir, - batch_size=10, - learning_rate=0.001) - reloaded_model.restore() - - 
pred_mols = ["CCCC", "CCCCCO", "CCCCC"] - X_pred = featurizer(pred_mols) - random_dataset = dc.data.NumpyDataset(X_pred) - original_pred = model.predict(random_dataset) - reload_pred = reloaded_model.predict(random_dataset) - assert np.all(original_pred == reload_pred) + # load datasets + featurizer = MolGraphConvFeaturizer() + tasks, dataset, transformers, metric = get_dataset( + 'classification', featurizer=featurizer) + + # initialize models + n_tasks = len(tasks) + model_dir = tempfile.mkdtemp() + model = GATModel( + mode='classification', + n_tasks=n_tasks, + number_atom_features=30, + model_dir=model_dir, + batch_size=10, + learning_rate=0.001) + + model.fit(dataset, nb_epoch=50) + scores = model.evaluate(dataset, [metric], transformers) + assert scores['mean-roc_auc_score'] >= 0.85 + + reloaded_model = GATModel( + mode='classification', + n_tasks=n_tasks, + number_atom_features=30, + model_dir=model_dir, + batch_size=10, + learning_rate=0.001) + reloaded_model.restore() + + pred_mols = ["CCCC", "CCCCCO", "CCCCC"] + X_pred = featurizer(pred_mols) + random_dataset = dc.data.NumpyDataset(X_pred) + original_pred = model.predict(random_dataset) + reload_pred = reloaded_model.predict(random_dataset) + assert np.all(original_pred == reload_pred) diff --git a/deepchem/models/torch_models/gat.py b/deepchem/models/torch_models/gat.py index 833bb3315d..df52586acf 100644 --- a/deepchem/models/torch_models/gat.py +++ b/deepchem/models/torch_models/gat.py @@ -7,6 +7,7 @@ from deepchem.models.losses import Loss, L2Loss, SparseSoftmaxCrossEntropy from deepchem.models.torch_models.torch_model import TorchModel + class GAT(nn.Module): """Model for Graph Property Prediction Based on Graph Attention Networks (GAT). @@ -50,6 +51,7 @@ class GAT(nn.Module): This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci (https://github.com/awslabs/dgl-lifesci) to be installed. """ + def __init__(self, n_tasks: int, graph_attention_layers: list = None, @@ -168,8 +170,7 @@ def __init__(self, activations=activation, n_tasks=out_size, predictor_hidden_feats=predictor_hidden_feats, - predictor_dropout=predictor_dropout - ) + predictor_dropout=predictor_dropout) def forward(self, g): """Predict graph labels @@ -247,6 +248,7 @@ class GATModel(TorchModel): This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci (https://github.com/awslabs/dgl-lifesci) to be installed. 
""" + def __init__(self, n_tasks: int, graph_attention_layers: list = None, From 274481b220fb5a492e1f50e45a41be1c81683c23 Mon Sep 17 00:00:00 2001 From: mufeili Date: Wed, 4 Nov 2020 18:37:49 +0800 Subject: [PATCH 3/8] Update --- deepchem/models/tests/test_gat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepchem/models/tests/test_gat.py b/deepchem/models/tests/test_gat.py index 51568bd8cb..e029d15d2b 100644 --- a/deepchem/models/tests/test_gat.py +++ b/deepchem/models/tests/test_gat.py @@ -57,7 +57,7 @@ def test_gat_classification(): learning_rate=0.001) # overfit test - model.fit(dataset, nb_epoch=50) + model.fit(dataset, nb_epoch=60) scores = model.evaluate(dataset, [metric], transformers) assert scores['mean-roc_auc_score'] >= 0.85 @@ -81,7 +81,7 @@ def test_gat_reload(): batch_size=10, learning_rate=0.001) - model.fit(dataset, nb_epoch=50) + model.fit(dataset, nb_epoch=60) scores = model.evaluate(dataset, [metric], transformers) assert scores['mean-roc_auc_score'] >= 0.85 From 9e5c9fd9dbab2b1cd1c6488593f4a2e3f0723cf4 Mon Sep 17 00:00:00 2001 From: mufeili Date: Thu, 5 Nov 2020 02:53:41 +0800 Subject: [PATCH 4/8] Update --- deepchem/models/__init__.py | 1 + deepchem/models/torch_models/__init__.py | 1 + deepchem/models/torch_models/attentivefp.py | 321 ++++++++++++++++++++ deepchem/models/torch_models/gat.py | 22 +- deepchem/models/torch_models/gcn.py | 22 +- 5 files changed, 347 insertions(+), 20 deletions(-) create mode 100644 deepchem/models/torch_models/attentivefp.py diff --git a/deepchem/models/__init__.py b/deepchem/models/__init__.py index d5c11877f3..d034e2a5aa 100644 --- a/deepchem/models/__init__.py +++ b/deepchem/models/__init__.py @@ -31,6 +31,7 @@ # PyTorch models try: from deepchem.models.torch_models import TorchModel + from deepchem.models.torch_models import AttentiveFP, AttentiveFPModel from deepchem.models.torch_models import CGCNN, CGCNNModel from deepchem.models.torch_models import GAT, GATModel from deepchem.models.torch_models import GCN, GCNModel diff --git a/deepchem/models/torch_models/__init__.py b/deepchem/models/torch_models/__init__.py index 7c2ab1b22a..611ede7000 100644 --- a/deepchem/models/torch_models/__init__.py +++ b/deepchem/models/torch_models/__init__.py @@ -1,5 +1,6 @@ # flake8:noqa from deepchem.models.torch_models.torch_model import TorchModel +from deepchem.models.torch_models.attentivefp import AttentiveFP, AttentiveFPModel from deepchem.models.torch_models.cgcnn import CGCNN, CGCNNModel from deepchem.models.torch_models.gat import GAT, GATModel from deepchem.models.torch_models.gcn import GCN, GCNModel diff --git a/deepchem/models/torch_models/attentivefp.py b/deepchem/models/torch_models/attentivefp.py new file mode 100644 index 0000000000..be61d1157e --- /dev/null +++ b/deepchem/models/torch_models/attentivefp.py @@ -0,0 +1,321 @@ +""" +DGL-based AttentiveFP for graph property prediction. +""" +import torch.nn as nn +import torch.nn.functional as F + +from deepchem.models.losses import Loss, L2Loss, SparseSoftmaxCrossEntropy +from deepchem.models.torch_models.torch_model import TorchModel + +class AttentiveFP(nn.Module): + """Model for Graph Property Prediction. 
+ + This model proceeds as follows: + + * Combine node features and edge features for initializing node representations, + which involves a round of message passing + * Update node representations with multiple rounds of message passing + * For each graph, compute its representation by combining the representations + of all nodes in it, which involves a gated recurrent unit (GRU). + * Perform the final prediction using a linear layer + + Examples + -------- + + >>> import deepchem as dc + >>> import dgl + >>> from deepchem.models import AttentiveFP + >>> smiles = ["C1CCC1", "C1=CC=CN=C1"] + >>> featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True) + >>> graphs = featurizer.featurize(smiles) + >>> print(type(graphs[0])) + + >>> dgl_graphs = [graphs[i].to_dgl_graph() for i in range(len(graphs))] + >>> # Batch two graphs into a graph of two connected components + >>> batch_dgl_graph = dgl.batch(dgl_graphs) + >>> model = AttentiveFP(n_tasks=1, mode='regression') + >>> preds = model(batch_dgl_graph) + >>> print(type(preds)) + + >>> preds.shape == (2, 1) + True + + References + ---------- + .. [1] Zhaoping Xiong, Dingyan Wang, Xiaohong Liu, Feisheng Zhong, Xiaozhe Wan, Xutong Li, + Zhaojun Li, Xiaomin Luo, Kaixian Chen, Hualiang Jiang, and Mingyue Zheng. "Pushing + the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention + Mechanism." Journal of Medicinal Chemistry. 2020, 63, 16, 8749–8760. + + Notes + ----- + This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci + (https://github.com/awslabs/dgl-lifesci) to be installed. + """ + + def __init__(self, + n_tasks: int, + num_layers: int = 2, + num_timesteps: int = 2, + graph_feat_size: int = 200, + dropout: float = 0., + mode: str = 'regression', + number_atom_features: int = 30, + number_bond_features: int = 11, + n_classes: int = 2, + nfeat_name: str = 'x', + efeat_name: str = 'edge_attr'): + """ + Parameters + ---------- + n_tasks: int + Number of tasks. + num_layers: int + Number of graph neural network layers, i.e. number of rounds of message passing. + Default to 2. + num_timesteps: int + Number of time steps for updating graph representations with a GRU. Default to 2. + graph_feat_size: int + Size for graph representations. Default to 200. + dropout: float + Dropout probability. Default to 0. + mode: str + The model type, 'classification' or 'regression'. Default to 'regression'. + number_atom_features: int + The length of the initial atom feature vectors. Default to 30. + number_bond_features: int + The length of the initial bond feature vectors. Default to 11. + n_classes: int + The number of classes to predict per task + (only used when ``mode`` is 'classification'). Default to 2. + nfeat_name: str + For an input graph ``g``, the model assumes that it stores node features in + ``g.ndata[nfeat_name]`` and will retrieve input node features from that. + Default to 'x'. + efeat_name: str + For an input graph ``g``, the model assumes that it stores edge features in + ``g.edata[efeat_name]`` and will retrieve input edge features from that. + Default to 'edge_attr'. 
+ """ + try: + import dgl + except: + raise ImportError('This class requires dgl.') + try: + import dgllife + except: + raise ImportError('This class requires dgllife.') + + if mode not in ['classification', 'regression']: + raise ValueError("mode must be either 'classification' or 'regression'") + + super(AttentiveFP, self).__init__() + + self.n_tasks = n_tasks + self.mode = mode + self.n_classes = n_classes + self.nfeat_name = nfeat_name + self.efeat_name = efeat_name + if mode == 'classification': + out_size = n_tasks * n_classes + else: + out_size = n_tasks + + from dgllife.model import AttentiveFPPredictor as DGLAttentiveFPPredictor + + self.model = DGLAttentiveFPPredictor(node_feat_size=number_atom_features, + edge_feat_size=number_bond_features, + num_layers=num_layers, + num_timesteps=num_timesteps, + graph_feat_size=graph_feat_size, + n_tasks=out_size, + dropout=dropout) + + def forward(self, g): + """Predict graph labels + + Parameters + ---------- + g: DGLGraph + A DGLGraph for a batch of graphs. It stores the node features in + ``dgl_graph.ndata[self.nfeat_name]`` and edge features in + ``dgl_graph.edata[self.efeat_name]``. + + Returns + ------- + torch.Tensor + The model output. + + * When self.mode = 'regression', + its shape will be ``(dgl_graph.batch_size, self.n_tasks)``. + * When self.mode = 'classification', the output consists of probabilities + for classes. Its shape will be + ``(dgl_graph.batch_size, self.n_tasks, self.n_classes)`` if self.n_tasks > 1; + its shape will be ``(dgl_graph.batch_size, self.n_classes)`` if self.n_tasks is 1. + torch.Tensor, optional + This is only returned when self.mode = 'classification', the output consists of the + logits for classes before softmax. + """ + node_feats = g.ndata[self.nfeat_name] + edge_feats = g.edata[self.efeat_name] + out = self.model(g, node_feats, edge_feats) + + if self.mode == 'classification': + if self.n_tasks == 1: + logits = out.view(-1, self.n_classes) + softmax_dim = 1 + else: + logits = out.view(-1, self.n_tasks, self.n_classes) + softmax_dim = 2 + proba = F.softmax(logits, dim=softmax_dim) + return proba, logits + else: + return out + + +class AttentiveFPModel(TorchModel): + """Model for Graph Property Prediction. + + This model proceeds as follows: + + * Combine node features and edge features for initializing node representations, + which involves a round of message passing + * Update node representations with multiple rounds of message passing + * For each graph, compute its representation by combining the representations + of all nodes in it, which involves a gated recurrent unit (GRU). + * Perform the final prediction using a linear layer + + Examples + -------- + + >>> + >> import deepchem as dc + >> from deepchem.models import AttentiveFPModel + >> featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True) + >> tasks, datasets, transformers = dc.molnet.load_tox21( + .. reload=False, featurizer=featurizer, transformers=[]) + >> train, valid, test = datasets + >> model = dc.models.AttentiveFPModel(mode='classification', n_tasks=len(tasks), + .. batch_size=32, learning_rate=0.001) + >> model.fit(train, nb_epoch=50) + + References + ---------- + .. [1] Zhaoping Xiong, Dingyan Wang, Xiaohong Liu, Feisheng Zhong, Xiaozhe Wan, Xutong Li, + Zhaojun Li, Xiaomin Luo, Kaixian Chen, Hualiang Jiang, and Mingyue Zheng. "Pushing + the Boundaries of Molecular Representation for Drug Discovery with the Graph + Attention Mechanism." Journal of Medicinal Chemistry. 2020, 63, 16, 8749–8760. 
+ + Notes + ----- + This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci + (https://github.com/awslabs/dgl-lifesci) to be installed. + """ + + def __init__(self, + n_tasks: int, + num_layers: int = 2, + num_timesteps: int = 2, + graph_feat_size: int = 200, + dropout: float = 0., + mode: str = 'regression', + number_atom_features: int = 30, + number_bond_features: int = 11, + n_classes: int = 2, + nfeat_name: str = 'x', + efeat_name: str = 'edge_attr', + self_loop: bool = True, + **kwargs): + """ + Parameters + ---------- + n_tasks: int + Number of tasks. + num_layers: int + Number of graph neural network layers, i.e. number of rounds of message passing. + Default to 2. + num_timesteps: int + Number of time steps for updating graph representations with a GRU. Default to 2. + graph_feat_size: int + Size for graph representations. Default to 200. + dropout: float + Dropout probability. Default to 0. + mode: str + The model type, 'classification' or 'regression'. Default to 'regression'. + number_atom_features: int + The length of the initial atom feature vectors. Default to 30. + number_bond_features: int + The length of the initial bond feature vectors. Default to 11. + n_classes: int + The number of classes to predict per task + (only used when ``mode`` is 'classification'). Default to 2. + nfeat_name: str + For an input graph ``g``, the model assumes that it stores node features in + ``g.ndata[nfeat_name]`` and will retrieve input node features from that. + Default to 'x'. + efeat_name: str + For an input graph ``g``, the model assumes that it stores edge features in + ``g.edata[efeat_name]`` and will retrieve input edge features from that. + Default to 'edge_attr'. + self_loop: bool + Whether to add self loops for the nodes, i.e. edges from nodes to themselves. + Default to True. + kwargs + This can include any keyword argument of TorchModel. + """ + model = AttentiveFP( + n_tasks=n_tasks, + num_layers=num_layers, + num_timesteps=num_timesteps, + graph_feat_size=graph_feat_size, + dropout=dropout, + mode=mode, + number_atom_features=number_atom_features, + number_bond_features=number_bond_features, + n_classes=n_classes, + nfeat_name=nfeat_name, + efeat_name=efeat_name) + if mode == 'regression': + loss: Loss = L2Loss() + output_types = ['prediction'] + else: + loss = SparseSoftmaxCrossEntropy() + output_types = ['prediction', 'loss'] + super(AttentiveFPModel, self).__init__( + model, loss=loss, output_types=output_types, **kwargs) + + self._self_loop = self_loop + + def _prepare_batch(self, batch): + """Create batch data for AttentiveFP. + + Parameters + ---------- + batch: tuple + The tuple is ``(inputs, labels, weights)``. + self_loop: bool + Whether to add self loops for the nodes, i.e. edges from nodes + to themselves. Default to False. + + Returns + ------- + inputs: DGLGraph + DGLGraph for a batch of graphs. + labels: list of torch.Tensor or None + The graph labels. + weights: list of torch.Tensor or None + The weights for each sample or sample/task pair converted to torch.Tensor. 
+ """ + try: + import dgl + except: + raise ImportError('This class requires dgl.') + + inputs, labels, weights = batch + dgl_graphs = [ + graph.to_dgl_graph(self_loop=self._self_loop) for graph in inputs[0] + ] + inputs = dgl.batch(dgl_graphs).to(self.device) + _, labels, weights = super(AttentiveFPModel, self)._prepare_batch(([], labels, + weights)) + return inputs, labels, weights diff --git a/deepchem/models/torch_models/gat.py b/deepchem/models/torch_models/gat.py index df52586acf..e0cba8caf3 100644 --- a/deepchem/models/torch_models/gat.py +++ b/deepchem/models/torch_models/gat.py @@ -34,7 +34,7 @@ class GAT(nn.Module): >>> dgl_graphs = [graphs[i].to_dgl_graph() for i in range(len(graphs))] >>> # Batch two graphs into a graph of two connected components >>> batch_dgl_graph = dgl.batch(dgl_graphs) - >>> model = GAT(n_tasks=1, number_atom_features=30, mode='regression') + >>> model = GAT(n_tasks=1, mode='regression') >>> preds = model(batch_dgl_graph) >>> print(type(preds)) @@ -64,7 +64,7 @@ def __init__(self, predictor_hidden_feats: int = 128, predictor_dropout: float = 0., mode: str = 'regression', - number_atom_features: int = 75, + number_atom_features: int = 30, n_classes: int = 2, nfeat_name: str = 'x'): """ @@ -101,15 +101,16 @@ def __init__(self, predictor_dropout: float The dropout probability in the output MLP predictor. Default to 0. mode: str - The model type, 'classification' or 'regression'. + The model type, 'classification' or 'regression'. Default to 'regression'. number_atom_features: int - The length of the initial atom feature vectors. Default to 75. + The length of the initial atom feature vectors. Default to 30. n_classes: int The number of classes to predict per task - (only used when ``mode`` is 'classification'). + (only used when ``mode`` is 'classification'). Default to 2. nfeat_name: str For an input graph ``g``, the model assumes that it stores node features in ``g.ndata[nfeat_name]`` and will retrieve input node features from that. + Default to 'x'. """ try: import dgl @@ -235,7 +236,7 @@ class GATModel(TorchModel): .. reload=False, featurizer=featurizer, transformers=[]) >> train, valid, test = datasets >> model = dc.models.GATModel(mode='classification', n_tasks=len(tasks), - .. number_atom_features=30, batch_size=32, learning_rate=0.001) + .. batch_size=32, learning_rate=0.001) >> model.fit(train, nb_epoch=50) References @@ -261,7 +262,7 @@ def __init__(self, predictor_hidden_feats: int = 128, predictor_dropout: float = 0., mode: str = 'regression', - number_atom_features: int = 75, + number_atom_features: int = 30, n_classes: int = 2, nfeat_name: str = 'x', self_loop: bool = True, @@ -300,15 +301,16 @@ def __init__(self, predictor_dropout: float The dropout probability in the output MLP predictor. Default to 0. mode: str - The model type, 'classification' or 'regression'. + The model type, 'classification' or 'regression'. Default to 'regression'. number_atom_features: int - The length of the initial atom feature vectors. Default to 75. + The length of the initial atom feature vectors. Default to 30. n_classes: int The number of classes to predict per task - (only used when ``mode`` is 'classification'). + (only used when ``mode`` is 'classification'). Default to 2. nfeat_name: str For an input graph ``g``, the model assumes that it stores node features in ``g.ndata[nfeat_name]`` and will retrieve input node features from that. + Default to 'x'. self_loop: bool Whether to add self loops for the nodes, i.e. edges from nodes to themselves. 
Default to True. diff --git a/deepchem/models/torch_models/gcn.py b/deepchem/models/torch_models/gcn.py index 74be264fe5..26d0168e32 100644 --- a/deepchem/models/torch_models/gcn.py +++ b/deepchem/models/torch_models/gcn.py @@ -34,7 +34,7 @@ class GCN(nn.Module): >>> dgl_graphs = [graphs[i].to_dgl_graph() for i in range(len(graphs))] >>> # Batch two graphs into a graph of two connected components >>> batch_dgl_graph = dgl.batch(dgl_graphs) - >>> model = GCN(n_tasks=1, number_atom_features=30, mode='regression') + >>> model = GCN(n_tasks=1, mode='regression') >>> preds = model(batch_dgl_graph) >>> print(type(preds)) @@ -77,7 +77,7 @@ def __init__(self, predictor_hidden_feats: int = 128, predictor_dropout: float = 0., mode: str = 'regression', - number_atom_features: int = 75, + number_atom_features: int = 30, n_classes: int = 2, nfeat_name: str = 'x'): """ @@ -103,15 +103,16 @@ def __init__(self, predictor_dropout: float The dropout probability in the output MLP predictor. Default to 0. mode: str - The model type, 'classification' or 'regression'. + The model type, 'classification' or 'regression'. Default to 'regression'. number_atom_features: int - The length of the initial atom feature vectors. Default to 75. + The length of the initial atom feature vectors. Default to 30. n_classes: int The number of classes to predict per task - (only used when ``mode`` is 'classification'). + (only used when ``mode`` is 'classification'). Default to 2. nfeat_name: str For an input graph ``g``, the model assumes that it stores node features in ``g.ndata[nfeat_name]`` and will retrieve input node features from that. + Default to 'x'. """ try: import dgl @@ -219,7 +220,7 @@ class GCNModel(TorchModel): .. reload=False, featurizer=featurizer, transformers=[]) >> train, valid, test = datasets >> model = dc.models.GCNModel(mode='classification', n_tasks=len(tasks), - .. number_atom_features=30, batch_size=32, learning_rate=0.001) + .. batch_size=32, learning_rate=0.001) >> model.fit(train, nb_epoch=50) References @@ -258,7 +259,7 @@ def __init__(self, predictor_hidden_feats: int = 128, predictor_dropout: float = 0., mode: str = 'regression', - number_atom_features=75, + number_atom_features=30, n_classes: int = 2, nfeat_name: str = 'x', self_loop: bool = True, @@ -286,15 +287,16 @@ def __init__(self, predictor_dropout: float The dropout probability in the output MLP predictor. Default to 0. mode: str - The model type, 'classification' or 'regression'. + The model type, 'classification' or 'regression'. Default to 'regression'. number_atom_features: int - The length of the initial atom feature vectors. Default to 75. + The length of the initial atom feature vectors. Default to 30. n_classes: int The number of classes to predict per task - (only used when ``mode`` is 'classification'). + (only used when ``mode`` is 'classification'). Default to 2. nfeat_name: str For an input graph ``g``, the model assumes that it stores node features in ``g.ndata[nfeat_name]`` and will retrieve input node features from that. + Default to 'x'. self_loop: bool Whether to add self loops for the nodes, i.e. edges from nodes to themselves. Default to True. 
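For context, the DGL-backed GATModel wired up in the patches above is driven end to end as in the following minimal sketch. This is only a sketch: it assumes PyTorch, DGL and DGL-LifeSci are importable, relies on MolGraphConvFeaturizer emitting 30-dimensional atom features (the new number_atom_features default), and uses made-up SMILES strings and labels purely for illustration.

import numpy as np
import deepchem as dc

# Featurize a handful of molecules into GraphData objects
# (30-dimensional atom features by default).
featurizer = dc.feat.MolGraphConvFeaturizer()
X = featurizer.featurize(["CCO", "c1ccccc1", "CC(=O)O"])
y = np.array([[0.1], [0.5], [0.3]])  # hypothetical regression targets
dataset = dc.data.NumpyDataset(X=X, y=y)

# DGL-based GAT; self loops are added per graph inside _prepare_batch,
# so the dataset itself needs no extra preprocessing.
model = dc.models.GATModel(
    mode='regression', n_tasks=1, batch_size=2, learning_rate=0.001)
model.fit(dataset, nb_epoch=10)
preds = model.predict(dataset)  # shape (3, 1) in regression mode

Because self_loop defaults to True in GATModel, each molecular graph is converted with to_dgl_graph(self_loop=True) at batching time; disable it by passing self_loop=False to the model constructor.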
From cb58af3406be89206e39f9f0f24c877756891500 Mon Sep 17 00:00:00 2001 From: mufeili Date: Thu, 5 Nov 2020 03:00:53 +0800 Subject: [PATCH 5/8] Update --- deepchem/feat/graph_data.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/deepchem/feat/graph_data.py b/deepchem/feat/graph_data.py index 666600af74..bbfd74be13 100644 --- a/deepchem/feat/graph_data.py +++ b/deepchem/feat/graph_data.py @@ -146,9 +146,6 @@ def to_dgl_graph(self, self_loop: bool = False): src = self.edge_index[0] dst = self.edge_index[1] - if self_loop: - src = np.concatenate([src, np.arange(self.num_nodes)]) - dst = np.concatenate([dst, np.arange(self.num_nodes)]) g = dgl.graph( (torch.from_numpy(src).long(), torch.from_numpy(dst).long()), @@ -161,6 +158,11 @@ def to_dgl_graph(self, self_loop: bool = False): if self.edge_features is not None: g.edata['edge_attr'] = torch.from_numpy(self.edge_features).float() + if self_loop: + # This assumes that the edge features for self loops are full-zero tensors + # In the future we may want to support featurization for self loops + g.add_edges(np.arange(self.num_nodes), np.arange(self.num_nodes)) + return g From 30ed432555248a6b63f71b18e1cebdcc36dd3e5a Mon Sep 17 00:00:00 2001 From: mufeili Date: Thu, 5 Nov 2020 03:13:29 +0800 Subject: [PATCH 6/8] Update --- deepchem/models/tests/test_attentivefp.py | 98 +++++++++++++++++++++++ docs/models.rst | 9 +++ 2 files changed, 107 insertions(+) create mode 100644 deepchem/models/tests/test_attentivefp.py diff --git a/deepchem/models/tests/test_attentivefp.py b/deepchem/models/tests/test_attentivefp.py new file mode 100644 index 0000000000..1a7cdbec54 --- /dev/null +++ b/deepchem/models/tests/test_attentivefp.py @@ -0,0 +1,98 @@ +import unittest +import tempfile + +import numpy as np + +import deepchem as dc +from deepchem.feat import MolGraphConvFeaturizer +from deepchem.models import AttentiveFPModel +from deepchem.models.tests.test_graph_models import get_dataset + +try: + import dgl + import dgllife + import torch + has_torch_and_dgl = True +except: + has_torch_and_dgl = False + + +@unittest.skipIf(not has_torch_and_dgl, + 'PyTorch, DGL, or DGL-LifeSci are not installed') +def test_attentivefp_regression(): + # load datasets + featurizer = MolGraphConvFeaturizer(use_edges=True) + tasks, dataset, transformers, metric = get_dataset( + 'regression', featurizer=featurizer) + + # initialize models + n_tasks = len(tasks) + model = AttentiveFPModel( + mode='regression', + n_tasks=n_tasks, + batch_size=10) + + # overfit test + model.fit(dataset, nb_epoch=100) + scores = model.evaluate(dataset, [metric], transformers) + assert scores['mean_absolute_error'] < 0.5 + + +@unittest.skipIf(not has_torch_and_dgl, + 'PyTorch, DGL, or DGL-LifeSci are not installed') +def test_attentivefp_classification(): + # load datasets + featurizer = MolGraphConvFeaturizer(use_edges=True) + tasks, dataset, transformers, metric = get_dataset( + 'classification', featurizer=featurizer) + + # initialize models + n_tasks = len(tasks) + model = AttentiveFPModel( + mode='classification', + n_tasks=n_tasks, + batch_size=10, + learning_rate=0.001) + + # overfit test + model.fit(dataset, nb_epoch=60) + scores = model.evaluate(dataset, [metric], transformers) + assert scores['mean-roc_auc_score'] >= 0.85 + + +@unittest.skipIf(not has_torch_and_dgl, + 'PyTorch, DGL, or DGL-LifeSci are not installed') +def test_attentivefp_reload(): + # load datasets + featurizer = MolGraphConvFeaturizer(use_edges=True) + tasks, dataset, transformers, metric = 
get_dataset( + 'classification', featurizer=featurizer) + + # initialize models + n_tasks = len(tasks) + model_dir = tempfile.mkdtemp() + model = AttentiveFPModel( + mode='classification', + n_tasks=n_tasks, + model_dir=model_dir, + batch_size=10, + learning_rate=0.001) + + model.fit(dataset, nb_epoch=60) + scores = model.evaluate(dataset, [metric], transformers) + assert scores['mean-roc_auc_score'] >= 0.85 + + reloaded_model = AttentiveFPModel( + mode='classification', + n_tasks=n_tasks, + model_dir=model_dir, + batch_size=10, + learning_rate=0.001) + reloaded_model.restore() + + pred_mols = ["CCCC", "CCCCCO", "CCCCC"] + X_pred = featurizer(pred_mols) + random_dataset = dc.data.NumpyDataset(X_pred) + original_pred = model.predict(random_dataset) + reload_pred = reloaded_model.predict(random_dataset) + assert np.all(original_pred == reload_pred) diff --git a/docs/models.rst b/docs/models.rst index ae6706e54d..4c692b086b 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -129,6 +129,9 @@ read off what's needed to train the model from the table below. | :code:`GCNModel` | Classifier/| :code:`GraphData` | | :code:`MolGraphConvFeaturizer` | :code:`fit` | | | Regressor | | | | | +----------------------------------------+------------+----------------------+------------------------+----------------------------------------------------------------+----------------------+ +| :code:`AttentiveFPModel` | Classifier/| :code:`GraphData` | | :code:`MolGraphConvFeaturizer` | :code:`fit` | +| | Regressor | | | | | ++----------------------------------------+------------+----------------------+------------------------+----------------------------------------------------------------+----------------------+ Model ----- @@ -456,3 +459,9 @@ GCNModel .. autoclass:: deepchem.models.GCNModel :members: + +AttentiveFPModel +---------------- + +.. autoclass:: deepchem.models.AttentiveFPModel + :members: From 65d3b190526ee57ec47159fb9834cf266fa9c1c9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 4 Nov 2020 19:18:08 +0000 Subject: [PATCH 7/8] Update --- deepchem/models/tests/test_attentivefp.py | 5 +- deepchem/models/torch_models/attentivefp.py | 124 ++++++++++---------- 2 files changed, 64 insertions(+), 65 deletions(-) diff --git a/deepchem/models/tests/test_attentivefp.py b/deepchem/models/tests/test_attentivefp.py index 1a7cdbec54..49b179b9eb 100644 --- a/deepchem/models/tests/test_attentivefp.py +++ b/deepchem/models/tests/test_attentivefp.py @@ -27,10 +27,7 @@ def test_attentivefp_regression(): # initialize models n_tasks = len(tasks) - model = AttentiveFPModel( - mode='regression', - n_tasks=n_tasks, - batch_size=10) + model = AttentiveFPModel(mode='regression', n_tasks=n_tasks, batch_size=10) # overfit test model.fit(dataset, nb_epoch=100) diff --git a/deepchem/models/torch_models/attentivefp.py b/deepchem/models/torch_models/attentivefp.py index be61d1157e..1447ab7ebc 100644 --- a/deepchem/models/torch_models/attentivefp.py +++ b/deepchem/models/torch_models/attentivefp.py @@ -7,6 +7,7 @@ from deepchem.models.losses import Loss, L2Loss, SparseSoftmaxCrossEntropy from deepchem.models.torch_models.torch_model import TorchModel + class AttentiveFP(nn.Module): """Model for Graph Property Prediction. 
@@ -123,13 +124,14 @@ def __init__(self, from dgllife.model import AttentiveFPPredictor as DGLAttentiveFPPredictor - self.model = DGLAttentiveFPPredictor(node_feat_size=number_atom_features, - edge_feat_size=number_bond_features, - num_layers=num_layers, - num_timesteps=num_timesteps, - graph_feat_size=graph_feat_size, - n_tasks=out_size, - dropout=dropout) + self.model = DGLAttentiveFPPredictor( + node_feat_size=number_atom_features, + edge_feat_size=number_bond_features, + num_layers=num_layers, + num_timesteps=num_timesteps, + graph_feat_size=graph_feat_size, + n_tasks=out_size, + dropout=dropout) def forward(self, g): """Predict graph labels @@ -174,7 +176,7 @@ def forward(self, g): class AttentiveFPModel(TorchModel): - """Model for Graph Property Prediction. + """Model for Graph Property Prediction. This model proceeds as follows: @@ -212,21 +214,21 @@ class AttentiveFPModel(TorchModel): (https://github.com/awslabs/dgl-lifesci) to be installed. """ - def __init__(self, - n_tasks: int, - num_layers: int = 2, - num_timesteps: int = 2, - graph_feat_size: int = 200, - dropout: float = 0., - mode: str = 'regression', - number_atom_features: int = 30, - number_bond_features: int = 11, - n_classes: int = 2, - nfeat_name: str = 'x', - efeat_name: str = 'edge_attr', - self_loop: bool = True, - **kwargs): - """ + def __init__(self, + n_tasks: int, + num_layers: int = 2, + num_timesteps: int = 2, + graph_feat_size: int = 200, + dropout: float = 0., + mode: str = 'regression', + number_atom_features: int = 30, + number_bond_features: int = 11, + n_classes: int = 2, + nfeat_name: str = 'x', + efeat_name: str = 'edge_attr', + self_loop: bool = True, + **kwargs): + """ Parameters ---------- n_tasks: int @@ -263,31 +265,31 @@ def __init__(self, kwargs This can include any keyword argument of TorchModel. """ - model = AttentiveFP( - n_tasks=n_tasks, - num_layers=num_layers, - num_timesteps=num_timesteps, - graph_feat_size=graph_feat_size, - dropout=dropout, - mode=mode, - number_atom_features=number_atom_features, - number_bond_features=number_bond_features, - n_classes=n_classes, - nfeat_name=nfeat_name, - efeat_name=efeat_name) - if mode == 'regression': - loss: Loss = L2Loss() - output_types = ['prediction'] - else: - loss = SparseSoftmaxCrossEntropy() - output_types = ['prediction', 'loss'] - super(AttentiveFPModel, self).__init__( - model, loss=loss, output_types=output_types, **kwargs) - - self._self_loop = self_loop - - def _prepare_batch(self, batch): - """Create batch data for AttentiveFP. + model = AttentiveFP( + n_tasks=n_tasks, + num_layers=num_layers, + num_timesteps=num_timesteps, + graph_feat_size=graph_feat_size, + dropout=dropout, + mode=mode, + number_atom_features=number_atom_features, + number_bond_features=number_bond_features, + n_classes=n_classes, + nfeat_name=nfeat_name, + efeat_name=efeat_name) + if mode == 'regression': + loss: Loss = L2Loss() + output_types = ['prediction'] + else: + loss = SparseSoftmaxCrossEntropy() + output_types = ['prediction', 'loss'] + super(AttentiveFPModel, self).__init__( + model, loss=loss, output_types=output_types, **kwargs) + + self._self_loop = self_loop + + def _prepare_batch(self, batch): + """Create batch data for AttentiveFP. Parameters ---------- @@ -306,16 +308,16 @@ def _prepare_batch(self, batch): weights: list of torch.Tensor or None The weights for each sample or sample/task pair converted to torch.Tensor. 
""" - try: - import dgl - except: - raise ImportError('This class requires dgl.') - - inputs, labels, weights = batch - dgl_graphs = [ - graph.to_dgl_graph(self_loop=self._self_loop) for graph in inputs[0] - ] - inputs = dgl.batch(dgl_graphs).to(self.device) - _, labels, weights = super(AttentiveFPModel, self)._prepare_batch(([], labels, - weights)) - return inputs, labels, weights + try: + import dgl + except: + raise ImportError('This class requires dgl.') + + inputs, labels, weights = batch + dgl_graphs = [ + graph.to_dgl_graph(self_loop=self._self_loop) for graph in inputs[0] + ] + inputs = dgl.batch(dgl_graphs).to(self.device) + _, labels, weights = super(AttentiveFPModel, self)._prepare_batch( + ([], labels, weights)) + return inputs, labels, weights From 18191211a576ec5ed42846e29e9f84a826e3b727 Mon Sep 17 00:00:00 2001 From: mufeili Date: Thu, 5 Nov 2020 03:34:52 +0800 Subject: [PATCH 8/8] Update --- docs/models.rst | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/models.rst b/docs/models.rst index 4c692b086b..68138fac75 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -202,9 +202,6 @@ Losses .. autoclass:: deepchem.models.losses.SparseSoftmaxCrossEntropy :members: -.. autoclass:: deepchem.models.losses.SparseSoftmaxCrossEntropy - :members: - .. autoclass:: deepchem.models.losses.VAE_ELBO :members: @@ -244,9 +241,6 @@ Optimizers .. autoclass:: deepchem.models.optimizers.LinearCosineDecay :members: -.. autoclass:: deepchem.models.optimizers.LinearCosineDecay - :members: - Keras Models ============