Skip to content

Commit

Permalink
Switching from BCELoss to BCEWithLogitsLoss
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Feb 7, 2019
1 parent 5376b54 commit 6d4aaa1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
26 changes: 20 additions & 6 deletions chemprop/models/model.py
Expand Up @@ -9,6 +9,18 @@
class MoleculeModel(nn.Module):
"""A MoleculeModel is a model which contains a message passing network following by feed-forward layers."""

def __init__(self, classification: bool):
"""
Initializes the MoleculeModel.
:param classification: Whether the model is a classification model.
"""
super(MoleculeModel, self).__init__()

self.classification = classification
if self.classification:
self.sigmoid = nn.Sigmoid()

def create_encoder(self, args: Namespace):
"""
Creates the message passing encoder for the model.
Expand Down Expand Up @@ -53,10 +65,6 @@ def create_ffn(self, args: Namespace):
nn.Linear(args.ffn_hidden_size, args.output_size),
])

# Classification
if args.dataset_type == 'classification':
ffn.append(nn.Sigmoid())

# Create FFN model
self.ffn = nn.Sequential(*ffn)

Expand All @@ -67,7 +75,13 @@ def forward(self, *input):
:param input: Input.
:return: The output of the MoleculeModel.
"""
return self.ffn(self.encoder(*input))
output = self.ffn(self.encoder(*input))

# Don't apply sigmoid during training b/c using BCEWithLogitsLoss
if self.classification and not self.training:
output = self.sigmoid(output)

return output


def build_model(args: Namespace) -> nn.Module:
Expand All @@ -80,7 +94,7 @@ def build_model(args: Namespace) -> nn.Module:
output_size = args.num_tasks
args.output_size = output_size

model = MoleculeModel()
model = MoleculeModel(classification=args.dataset_type == 'classification')
model.create_encoder(args)
model.create_ffn(args)

Expand Down
2 changes: 1 addition & 1 deletion chemprop/utils.py
Expand Up @@ -144,7 +144,7 @@ def get_loss_func(args: Namespace) -> nn.Module:
:return: A PyTorch loss function.
"""
if args.dataset_type == 'classification':
return nn.BCELoss(reduction='none')
return nn.BCEWithLogitsLoss(reduction='none')

if args.dataset_type == 'regression':
return nn.MSELoss(reduction='none')
Expand Down

0 comments on commit 6d4aaa1

Please sign in to comment.