Replace multiclass mcc with 1-mcc for loss (#332)
* Replace multiclass mcc with 1-mcc for loss

* Loss = 1 for undefined mcc

* Update unit test for mcc multiclass loss

* Remove batch size normalization from mcc loss

* Format train.py file

* Remove mask sum multiplier from mcc

* Specify device inside mcc loss function
cjmcgill committed Oct 18, 2022
1 parent 442a160 commit 2d00296
Showing 3 changed files with 54 additions and 41 deletions.
7 changes: 4 additions & 3 deletions chemprop/train/loss_functions.py
@@ -115,7 +115,7 @@ def mcc_multiclass_loss(
     mask: torch.tensor,
 ) -> torch.tensor:
     """
-    A multiclass loss using a soft version of the Matthews Correlation Coefficient. Multiclass definition follows the version in sklearn documentation.
+    A multiclass loss using a soft version of the Matthews Correlation Coefficient. Multiclass definition follows the version in sklearn documentation (https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-correlation-coefficient).

     :param predictions: Model predictions with shape(batch_size, classes).
     :param targets: Target values with shape(batch_size).
@@ -146,9 +146,10 @@ def mcc_multiclass_loss(
     cov_ytyt = n_samples**2 - torch.dot(t_sum, t_sum)

     if cov_ypyp * cov_ytyt == 0:
-        loss = torch.tensor(0.0)
+        loss = torch.tensor(1.0, device=torch_device)
     else:
-        loss = cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp)
+        mcc = cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp)
+        loss = 1 - mcc

     return loss

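For reference, the multiclass MCC that the docstring now links can be written in the covariance form the code computes, with s = n_samples, c = number of correct predictions, p_k = count of predictions of class k, and t_k = count of true occurrences of class k (mapping cov_ytyp and cov_ypyp onto the sklearn formula is our reading; only cov_ytyt is visible in this hunk):

    \mathrm{MCC}
      = \frac{\mathrm{cov}_{ytyp}}{\sqrt{\mathrm{cov}_{ytyt} \cdot \mathrm{cov}_{ypyp}}}
      = \frac{c \cdot s - \sum_k p_k t_k}{\sqrt{\left(s^2 - \sum_k p_k^2\right)\left(s^2 - \sum_k t_k^2\right)}},
    \qquad \text{loss} = 1 - \mathrm{MCC}

Since MCC lies in [-1, 1], the loss 1 - MCC lies in [0, 2] and is 0 for a perfect classifier, so it can be minimized directly; returning 1.0 when cov_ypyp * cov_ytyt == 0 matches the usual convention of treating an undefined MCC as 0.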
78 changes: 45 additions & 33 deletions chemprop/train/train.py
@@ -14,15 +14,17 @@
 from chemprop.nn_utils import compute_gnorm, compute_pnorm, NoamLR


-def train(model: MoleculeModel,
-          data_loader: MoleculeDataLoader,
-          loss_func: Callable,
-          optimizer: Optimizer,
-          scheduler: _LRScheduler,
-          args: TrainArgs,
-          n_iter: int = 0,
-          logger: logging.Logger = None,
-          writer: SummaryWriter = None) -> int:
+def train(
+    model: MoleculeModel,
+    data_loader: MoleculeDataLoader,
+    loss_func: Callable,
+    optimizer: Optimizer,
+    scheduler: _LRScheduler,
+    args: TrainArgs,
+    n_iter: int = 0,
+    logger: logging.Logger = None,
+    writer: SummaryWriter = None,
+) -> int:
     """
     Trains a model for an epoch.
@@ -50,65 +52,75 @@ def train(model: MoleculeModel,
             batch.atom_features(), batch.bond_features(), batch.data_weights()

         mask = torch.tensor(mask_batch, dtype=torch.bool)  # shape(batch, tasks)
-        targets = torch.tensor([[0 if x is None else x for x in tb] for tb in target_batch]) # shape(batch, tasks)
+        targets = torch.tensor([[0 if x is None else x for x in tb] for tb in target_batch])  # shape(batch, tasks)

         if args.target_weights is not None:
             target_weights = torch.tensor(args.target_weights).unsqueeze(0)  # shape(1,tasks)
         else:
             target_weights = torch.ones(targets.shape[1]).unsqueeze(0)
-        data_weights = torch.tensor(data_weights_batch).unsqueeze(1) # shape(batch,1)
+        data_weights = torch.tensor(data_weights_batch).unsqueeze(1)  # shape(batch,1)

-        if args.loss_function == 'bounded_mse':
-            lt_target_batch = batch.lt_targets() # shape(batch, tasks)
-            gt_target_batch = batch.gt_targets() # shape(batch, tasks)
+        if args.loss_function == "bounded_mse":
+            lt_target_batch = batch.lt_targets()  # shape(batch, tasks)
+            gt_target_batch = batch.gt_targets()  # shape(batch, tasks)
             lt_target_batch = torch.tensor(lt_target_batch)
             gt_target_batch = torch.tensor(gt_target_batch)

         # Run model
         model.zero_grad()
-        preds = model(mol_batch, features_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch)
+        preds = model(
+            mol_batch,
+            features_batch,
+            atom_descriptors_batch,
+            atom_features_batch,
+            bond_features_batch,
+        )

         # Move tensors to correct device
         torch_device = preds.device
         mask = mask.to(torch_device)
         targets = targets.to(torch_device)
         target_weights = target_weights.to(torch_device)
         data_weights = data_weights.to(torch_device)
-        if args.loss_function == 'bounded_mse':
+        if args.loss_function == "bounded_mse":
             lt_target_batch = lt_target_batch.to(torch_device)
             gt_target_batch = gt_target_batch.to(torch_device)

         # Calculate losses
-        if args.loss_function == 'mcc' and args.dataset_type == 'classification':
-            loss = loss_func(preds, targets, data_weights, mask) *target_weights.squeeze(0)
-        elif args.loss_function == 'mcc': # multiclass dataset type
+        if args.loss_function == "mcc" and args.dataset_type == "classification":
+            loss = loss_func(preds, targets, data_weights, mask) * target_weights.squeeze(0)
+        elif args.loss_function == "mcc":  # multiclass dataset type
             targets = targets.long()
             target_losses = []
             for target_index in range(preds.size(1)):
                 target_loss = loss_func(preds[:, target_index, :], targets[:, target_index], data_weights, mask[:, target_index]).unsqueeze(0)
                 target_losses.append(target_loss)
-            loss = torch.cat(target_losses).to(torch_device) * target_weights.squeeze(0)
-        elif args.dataset_type == 'multiclass':
+            loss = torch.cat(target_losses) * target_weights.squeeze(0)
+        elif args.dataset_type == "multiclass":
             targets = targets.long()
-            if args.loss_function == 'dirichlet':
+            if args.loss_function == "dirichlet":
                 loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
             else:
                 target_losses = []
                 for target_index in range(preds.size(1)):
                     target_loss = loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                     target_losses.append(target_loss)
                 loss = torch.cat(target_losses, dim=1).to(torch_device) * target_weights * data_weights * mask
-        elif args.dataset_type == 'spectra':
+        elif args.dataset_type == "spectra":
             loss = loss_func(preds, targets, mask) * target_weights * data_weights * mask
-        elif args.loss_function == 'bounded_mse':
+        elif args.loss_function == "bounded_mse":
             loss = loss_func(preds, targets, lt_target_batch, gt_target_batch) * target_weights * data_weights * mask
-        elif args.loss_function == 'evidential':
+        elif args.loss_function == "evidential":
             loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
-        elif args.loss_function == 'dirichlet': # classification
+        elif args.loss_function == "dirichlet":  # classification
             loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
         else:
             loss = loss_func(preds, targets) * target_weights * data_weights * mask
-        loss = loss.sum() / mask.sum()
+
+        if args.loss_function == "mcc":
+            loss = loss.mean()
+        else:
+            loss = loss.sum() / mask.sum()

         loss_sum += loss.item()
         iter_count += 1
@@ -131,14 +143,14 @@ def train(model: MoleculeModel,
             loss_avg = loss_sum / iter_count
             loss_sum = iter_count = 0

-            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
-            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')
+            lrs_str = ", ".join(f"lr_{i} = {lr:.4e}" for i, lr in enumerate(lrs))
+            debug(f"Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}")

             if writer is not None:
-                writer.add_scalar('train_loss', loss_avg, n_iter)
-                writer.add_scalar('param_norm', pnorm, n_iter)
-                writer.add_scalar('gradient_norm', gnorm, n_iter)
+                writer.add_scalar("train_loss", loss_avg, n_iter)
+                writer.add_scalar("param_norm", pnorm, n_iter)
+                writer.add_scalar("gradient_norm", gnorm, n_iter)
                 for i, lr in enumerate(lrs):
-                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)
+                    writer.add_scalar(f"learning_rate_{i}", lr, n_iter)

     return n_iter
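Why the reduction now branches: the elementwise losses above have shape (batch, tasks) and are averaged over unmasked entries, whereas the MCC loss is already reduced over the batch inside the loss function and arrives as one value per task, so dividing it by mask.sum() would shrink it by the batch size a second time. A minimal sketch with made-up tensors (not chemprop code):

    import torch

    # Elementwise case (e.g. MSE): shape (batch, tasks), masked entries zeroed.
    mask = torch.tensor([[True, True], [True, False], [True, True]])
    elementwise_loss = torch.rand(3, 2) * mask
    reduced = elementwise_loss.sum() / mask.sum()  # mean over the 5 valid entries

    # MCC case: one scalar per task, already aggregated over the batch.
    per_task_loss = torch.tensor([0.39, 1.25])  # hypothetical 1 - MCC values
    reduced_mcc = per_task_loss.mean()           # mean over tasks only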
10 changes: 5 additions & 5 deletions tests/test_unit/test_loss_functions.py
@@ -226,35 +226,35 @@ def test_evidential_wrong_dimensions(alphas, targets):
             torch.tensor([0, 0, 0], dtype=int),
             torch.tensor([[1], [1], [1]], dtype=float),
             torch.tensor([True, True, True], dtype=bool),
-            0.0,
+            1 - 0.0,
         ),
         (
             torch.tensor([[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [0.6, 0.3, 0.1]], dtype=float),
             torch.tensor([1, 2, 0], dtype=int),
             torch.tensor([[1], [1], [1]], dtype=float),
             torch.tensor([True, True, True], dtype=bool),
-            0.0,
+            1 - 0.0,
         ),
         (
             torch.tensor([[0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]], dtype=float),
             torch.tensor([2, 0, 2], dtype=int),
             torch.tensor([[1], [1], [1]], dtype=float),
             torch.tensor([True, True, True], dtype=bool),
-            0.6123724356957946,
+            1 - 0.6123724356957946,
         ),
         (
             torch.tensor([[0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]], dtype=float),
             torch.tensor([2, 0, 2], dtype=int),
             torch.tensor([[0.5], [1], [1.5]], dtype=float),
             torch.tensor([True, True, True], dtype=bool),
-            0.7462025072446364,
+            1 - 0.7462025072446364,
         ),
         (
             torch.tensor([[0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]], dtype=float),
             torch.tensor([2, 0, 2], dtype=int),
             torch.tensor([[1], [1], [1]], dtype=float),
             torch.tensor([False, True, True], dtype=bool),
-            1.0,
+            1 - 1.0,
         ),
     ],
 )
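For orientation, a sketch of how one of the updated cases exercises the loss. The call order mirrors the parametrized tuples (predictions, targets, data_weights, mask, expected); the surrounding test scaffolding is not shown in this hunk, so the exact signature is inferred:

    import torch
    from chemprop.train.loss_functions import mcc_multiclass_loss

    preds = torch.tensor([[0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]], dtype=float)
    targets = torch.tensor([2, 0, 2], dtype=int)
    data_weights = torch.tensor([[1], [1], [1]], dtype=float)
    mask = torch.tensor([True, True, True], dtype=bool)

    loss = mcc_multiclass_loss(preds, targets, data_weights, mask)
    # Expected after this commit: 1 - 0.6123724356957946 ≈ 0.3876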
