Adding weighted_skip parameter to MultilayerPerceptron layer #3494

Merged: 3 commits, Jul 27, 2023
25 changes: 23 additions & 2 deletions deepchem/models/tests/test_layers.py
@@ -664,7 +664,7 @@ def test_position_wise_feed_forward():
(True, False, [[-0.9612, 2.3846], [-4.1104, 5.7606]]),
(False, True, [[0.2795, 0.4243], [0.2795, 0.4243]]),
(True, True, [[-0.9612, 2.3846], [-4.1104, 5.7606]])])
def test_MultilayerPerceptron(skip_connection, batch_norm, expected):
def test_multilayer_perceptron(skip_connection, batch_norm, expected):
"""Test invoking MLP."""
torch.manual_seed(0)
input_ar = torch.tensor([[1., 2.], [5., 6.]])
@@ -681,7 +681,7 @@ def test_MultilayerPerceptron(skip_connection, batch_norm, expected):


@pytest.mark.torch
def test_MultilayerPerceptron_overfit():
def test_multilayer_perceptron_overfit():
import torch
import deepchem.models.torch_models.layers as torch_layers
from deepchem.data import NumpyDataset
@@ -703,6 +703,27 @@ def test_MultilayerPerceptron_overfit():
assert np.allclose(output, y, atol=1e-2)


@pytest.mark.torch
def test_weighted_skip_multilayer_perceptron():
"Test for weighted skip connection from the input to the output"
seed = 123
torch.manual_seed(seed)
dim = 1
features = torch.Tensor([[0.8343], [1.2713], [1.2713], [1.2713], [1.2713]])
layer = dc.models.torch_models.layers.MultilayerPerceptron(
d_input=dim,
d_hidden=(dim,),
d_output=dim,
activation_fn='silu',
skip_connection=True,
weighted_skip=False)
output = layer(features)
output = output.detach().numpy()
result = np.array([[1.1032], [1.5598], [1.5598], [1.5598], [1.5598]])
assert np.allclose(output, result, atol=1e-04)
assert output.shape == (5, 1)


@pytest.mark.torch
def test_position_wise_feed_forward_dropout_at_input():
"""Test invoking PositionwiseFeedForward."""
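The new weighted_skip test added above can be run on its own with pytest's keyword filter (a sketch, assuming a DeepChem development checkout with the torch extras installed):

pytest deepchem/models/tests/test_layers.py -k weighted_skip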
11 changes: 9 additions & 2 deletions deepchem/models/torch_models/layers.py
@@ -40,7 +40,8 @@ def __init__(self,
batch_norm: bool = False,
batch_norm_momentum: float = 0.1,
activation_fn: Union[Callable, str] = 'relu',
skip_connection: bool = False):
skip_connection: bool = False,
weighted_skip: bool = True):
"""Initialize the model.

Parameters
@@ -61,6 +62,8 @@
the activation function to use in the hidden layers
skip_connection: bool
whether to add a skip connection from the input to the output
weighted_skip: bool
whether the skip connection applies a learned linear projection to the input before adding it to the output; if False, the raw input is added directly
"""
super(MultilayerPerceptron, self).__init__()
self.d_input = d_input
@@ -72,6 +75,7 @@
self.activation_fn = get_activation(activation_fn)
self.model = nn.Sequential(*self.build_layers())
self.skip = nn.Linear(d_input, d_output) if skip_connection else None
self.weighted_skip = weighted_skip

def build_layers(self):
"""
@@ -101,7 +105,10 @@ def forward(self, x: Tensor) -> Tensor:
x
) # Done because activation_fn returns a torch.nn.functional
if self.skip is not None:
return x + self.skip(input)
if not self.weighted_skip:
return x + input
else:
return x + self.skip(input)
else:
return x

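For context, here is a minimal usage sketch of the new parameter (the layer sizes and random input are illustrative, not taken from the PR): with weighted_skip=True the skip path goes through the learned nn.Linear(d_input, d_output) projection shown above, while weighted_skip=False adds the raw input directly to the output, which requires d_input to equal d_output.

import torch
from deepchem.models.torch_models.layers import MultilayerPerceptron

torch.manual_seed(0)
x = torch.randn(4, 8)

# Learned skip: the input is projected by nn.Linear(d_input, d_output)
# before being added to the MLP output (the previous behaviour).
mlp_weighted = MultilayerPerceptron(d_input=8,
                                    d_hidden=(16,),
                                    d_output=8,
                                    skip_connection=True,
                                    weighted_skip=True)

# Identity skip: the raw input is added directly to the MLP output,
# so d_input must equal d_output for the shapes to line up.
mlp_identity = MultilayerPerceptron(d_input=8,
                                    d_hidden=(16,),
                                    d_output=8,
                                    skip_connection=True,
                                    weighted_skip=False)

print(mlp_weighted(x).shape)  # torch.Size([4, 8])
print(mlp_identity(x).shape)  # torch.Size([4, 8])

Because weighted_skip defaults to True, callers that already pass skip_connection=True keep the previous projected-skip behaviour unchanged.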