🐛 Bug
To Reproduce
Steps to reproduce the behavior:
- Modifying one of the following AttentiveFPModel parameters from its default:
  - number_bond_features (default: 11)
  - number_atom_features (default: 30)
- Fitting the model, using a modified version of the deepchem AttentiveFPModel example:
import deepchem as dc
from deepchem.models import AttentiveFPModel

smiles = ["C1CCC1", "C1=CC=CN=C1"]
labels = [0., 1.]
featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
X = featurizer.featurize(smiles)
dataset = dc.data.NumpyDataset(X=X, y=labels)

# training model
model = AttentiveFPModel(mode='classification',
                         n_tasks=1,
                         batch_size=16,
                         number_bond_features=12,
                         learning_rate=0.001)
loss = model.fit(dataset, nb_epoch=5)
Traceback (most recent call last):
File "/home/matt/test_script/attentive_dc_test.py", line 16, in <module>
loss = model.fit(dataset, nb_epoch=5)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/deepchem/models/torch_models/torch_model.py", line 334, in fit
return self.fit_generator(
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/deepchem/models/torch_models/torch_model.py", line 424, in fit_generator
outputs = self.model(inputs)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/deepchem/models/torch_models/attentivefp.py", line 163, in forward
out = self.model(g, node_feats, edge_feats)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/dgllife/model/model_zoo/attentivefp_predictor.py", line 87, in forward
node_feats = self.gnn(g, node_feats, edge_feats)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/dgllife/model/gnn/attentivefp.py", line 357, in forward
node_feats = self.init_context(g, node_feats, edge_feats)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/dgllife/model/gnn/attentivefp.py", line 223, in forward
g.ndata['hv_new'] = self.project_node(node_feats)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
return F.linear(input, self.weight, self.bias)
File "/home/matt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/functional.py", line 1848, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (80x30 and 31x200)
Expected behavior
The model fits without error and returns the loss.
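The failure in the last traceback frame reduces to a plain matrix multiply: the featurizer emits a fixed number of features per atom, while the weight of the model's first Linear layer was built for a different input width. A minimal NumPy sketch of the same mismatch (the 80x30 input and 31x200 weight are the concrete shapes reported in the traceback; the weight contents here are random placeholders):

```python
import numpy as np

# Node features as produced by the featurizer: 80 atoms x 30 features.
node_feats = np.random.rand(80, 30)

# A layer built for a 31-dim input; F.linear computes input @ weight.T,
# so a (200, 31) weight gives the (31, 200) mat2 seen in the error.
weight_mismatched = np.random.rand(200, 31)
try:
    node_feats @ weight_mismatched.T
except ValueError as e:
    print("shape mismatch:", e)

# Building the layer for the featurizer's actual width succeeds.
weight_matched = np.random.rand(200, 30)
out = node_feats @ weight_matched.T
print(out.shape)  # (80, 200)
```

This is why the constructor parameters have to agree with what MolGraphConvFeaturizer actually produces: changing number_bond_features or number_atom_features resizes the model's layers without changing the featurized input.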
Environment
- OS: Red Hat Enterprise Linux 8.7 (Ootpa)
- Python version: 3.9.12
- DeepChem version: 2.6.1
- RDKit version (optional): 2022.9.1
- TensorFlow version (optional):
- PyTorch version (optional): 1.12.1
- Any other relevant information: