AttentiveFPModel -- Cannot change number_bond_features or number_atom_features parameter value without error while running model.fit() #3133

@mattaadams

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Modify one of the following AttentiveFPModel parameters from its default value:

Default: number_bond_features=11,
Default: number_atom_features=30,

  2. Fit the model, using a modified version of the DeepChem AttentiveFPModel example:

import deepchem as dc
from deepchem.models import AttentiveFPModel

smiles = ["C1CCC1", "C1=CC=CN=C1"]
labels = [0., 1.]
featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
X = featurizer.featurize(smiles)
dataset = dc.data.NumpyDataset(X=X, y=labels)

# Train the model with number_bond_features changed from its default of 11
model = AttentiveFPModel(mode='classification',
                         n_tasks=1,
                         batch_size=16,
                         number_bond_features=12,
                         learning_rate=0.001)
loss = model.fit(dataset, nb_epoch=5)

Running the script raises:

Traceback (most recent call last):
  File "/home/matt/test_script/attentive_dc_test.py", line 16, in <module>
    loss = model.fit(dataset, nb_epoch=5)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/deepchem/models/torch_models/torch_model.py", line 334, in fit
    return self.fit_generator(
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/deepchem/models/torch_models/torch_model.py", line 424, in fit_generator
    outputs = self.model(inputs)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/deepchem/models/torch_models/attentivefp.py", line 163, in forward
    out = self.model(g, node_feats, edge_feats)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/dgllife/model/model_zoo/attentivefp_predictor.py", line 87, in forward
    node_feats = self.gnn(g, node_feats, edge_feats)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/dgllife/model/gnn/attentivefp.py", line 357, in forward
    node_feats = self.init_context(g, node_feats, edge_feats)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/dgllife/model/gnn/attentivefp.py", line 223, in forward
    g.ndata['hv_new'] = self.project_node(node_feats)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (80x30 and 31x200)
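
For anyone reading the last line of the traceback: the check that fails is the ordinary matrix-multiplication rule that the inner dimensions must agree. The sketch below reproduces that check in plain Python with the shapes taken from the error message. It is not DeepChem or PyTorch code, and `linear_output_shape` is a hypothetical helper written only for illustration. The interpretation of the shapes (80 atoms in the batch with 30 features each, and a projection layer built for 31 input features) is an assumption based on the error text and on the failing `project_node` call.

```python
def linear_output_shape(input_shape, weight_shape):
    """Return the shape of input @ weight, or raise if the inner dimensions differ.

    Hypothetical helper that mimics the shape check behind the RuntimeError.
    """
    rows, in_features = input_shape
    expected_in, out_features = weight_shape
    if in_features != expected_in:
        raise ValueError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{in_features} and {expected_in}x{out_features})"
        )
    return (rows, out_features)


# Shapes copied from the RuntimeError in the traceback.
node_feats_shape = (80, 30)  # assumed: 80 atoms in the batch, 30 features per atom
weight_shape = (31, 200)     # assumed: projection layer built for 31 input features

try:
    linear_output_shape(node_feats_shape, weight_shape)
except ValueError as err:
    print(err)

# When the layer width matches the feature width, the product is well defined.
print(linear_output_shape((80, 30), (30, 200)))
```

If that reading is right, the layer is sized from the constructor argument while the feature matrix is sized by the featurizer, so changing one without the other leaves the two widths out of step.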

Expected behavior

Fits the model without error, returning loss.

Environment

  • OS: Red Hat Enterprise Linux 8.7 (Ootpa)
  • Python version: 3.9.12
  • DeepChem version: 2.6.1
  • RDKit version (optional): 2022.9.1
  • TensorFlow version (optional):
  • PyTorch version (optional): 1.12.1
  • Any other relevant information:

Additional context
