Merge pull request #706 from sarthakpati/705-add-matthews-correlation-coefficient-loss

Added matthews correlation coefficient loss
sarthakpati committed Aug 8, 2023
2 parents b5338fe + b68b0ee commit d12ca21
Showing 7 changed files with 180 additions and 39 deletions.
3 changes: 3 additions & 0 deletions GANDLF/compute/generic.py
@@ -97,6 +97,9 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
parameters["class_weights"],
) = get_class_imbalance_weights(parameters["training_data"], parameters)

print("Class weights : ", parameters["class_weights"])
print("Penalty weights: ", parameters["weights"])

else:
scheduler = None

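For context, the "Penalty weights" printed above come from get_class_imbalance_weights, which the docs further down describe as weighting the loss by the inverse of the class frequency. A minimal sketch of that idea, stated as an assumption for illustration rather than GaNDLF's actual implementation:

import torch

def inverse_frequency_weights(label_counts: torch.Tensor) -> torch.Tensor:
    # label_counts[i] = number of voxels/samples belonging to class i
    frequencies = label_counts / label_counts.sum()
    weights = 1.0 / (frequencies + torch.finfo(torch.float32).eps)
    return weights / weights.sum()  # normalize so the weights sum to 1

print(inverse_frequency_weights(torch.tensor([900.0, 90.0, 10.0])))
# rarer classes receive larger penalty weights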
9 changes: 9 additions & 0 deletions GANDLF/losses/__init__.py
@@ -7,6 +7,8 @@
MCT_loss,
KullbackLeiblerDivergence,
FocalLoss,
MCC_loss,
MCC_log_loss,
)
from .regression import CE, CEL, MSE_loss, L1_loss
from .hybrid import DCCE, DCCE_Logits, DC_Focal
@@ -17,7 +19,14 @@
"dc": MCD_loss,
"dice": MCD_loss,
"dc_log": MCD_log_loss,
"dclog": MCD_log_loss,
"dice_log": MCD_log_loss,
"dicelog": MCD_log_loss,
"mcc": MCC_loss,
"mcc_log": MCC_log_loss,
"mcclog": MCC_log_loss,
"mathews": MCC_loss,
"mathews_log": MCC_log_loss,
"dcce": DCCE,
"dcce_logits": DCCE_Logits,
"ce": CE,
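The new dictionary entries above mean that a configuration value such as loss_function: mcc (or the mathews alias) resolves to MCC_loss. A hedged sketch of how such a name-to-function mapping is typically consumed; the names losses_dict and get_loss are hypothetical and not taken from the repository:

from GANDLF.losses import MCC_loss, MCC_log_loss, MCD_loss

losses_dict = {
    "dc": MCD_loss,
    "mcc": MCC_loss,
    "mcc_log": MCC_log_loss,
    "mathews": MCC_loss,  # alias spelled exactly as in this commit
}

def get_loss(name: str):
    # the sample config asks for lower-case keys, so normalize before the lookup
    return losses_dict[name.lower()]

loss_fn = get_loss("MCC")  # -> MCC_loss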
183 changes: 154 additions & 29 deletions GANDLF/losses/segmentation.py
@@ -3,13 +3,13 @@


# Dice scores and dice losses
def dice(predicted, target) -> torch.Tensor:
def dice(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
This function computes a dice score between two tensors.
Args:
predicted (_type_): Predicted value by the network.
target (_type_): Required target label to match the predicted with
predicted (torch.Tensor): Predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with
Returns:
torch.Tensor: The computed dice score.
@@ -25,76 +25,197 @@ def dice(predicted, target) -> torch.Tensor:
return dice_score
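The body of dice() is collapsed in this view. For orientation only, a typical soft-dice score has the following shape; this is a sketch of the standard formulation and an assumption about the hidden body, not a quote of it:

import torch

def soft_dice(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # 2 * intersection / (sum of magnitudes), with eps to avoid division by zero
    intersection = (predicted * target).sum()
    return (2.0 * intersection + torch.finfo(torch.float32).eps) / (
        predicted.sum() + target.sum() + torch.finfo(torch.float32).eps
    )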


def MCD(predicted, target, num_class, weights=None, ignore_class=None, loss_type=0):
def mcc(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""
This function computes the Matthews Correlation Coefficient (MCC) between two tensors. Adapted from https://github.com/kakumarabhishek/MCC-Loss/blob/main/loss.py.
Args:
predictions (torch.Tensor): The predicted value by the network.
targets (torch.Tensor): Required target label to match the predicted with
Returns:
torch.Tensor: The computed MCC score.
"""
tp = torch.sum(torch.mul(predictions, targets))
tn = torch.sum(torch.mul((1 - predictions), (1 - targets)))
fp = torch.sum(torch.mul(predictions, (1 - targets)))
fn = torch.sum(torch.mul((1 - predictions), targets))

numerator = torch.mul(tp, tn) - torch.mul(fp, fn)
# Adding epsilon to the denominator to avoid divide-by-zero errors.
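# note: torch.add(a, 1, b) is the legacy three-argument form of a + 1 * b, so each factor below is a pairwise sum such as (tp + fp)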
denominator = (
torch.sqrt(
torch.add(tp, 1, fp)
* torch.add(tp, 1, fn)
* torch.add(tn, 1, fp)
* torch.add(tn, 1, fn)
)
+ torch.finfo(torch.float32).eps
)

return torch.div(numerator.sum(), denominator.sum())
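For reference, a small self-contained check of the MCC formula used above, with illustrative values (not part of the commit):

import torch

predictions = torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0])
targets = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0])

tp = (predictions * targets).sum()              # 2 true positives
tn = ((1 - predictions) * (1 - targets)).sum()  # 2 true negatives
fp = (predictions * (1 - targets)).sum()        # 1 false positive
fn = ((1 - predictions) * targets).sum()        # 0 false negatives

mcc_value = (tp * tn - fp * fn) / torch.sqrt(
    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
)
print(mcc_value)  # 4 / 6 = 0.6667, matching mcc() above up to the epsilon term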


def generic_loss_calculator(
predicted: torch.Tensor,
target: torch.Tensor,
num_class: int,
loss_criteria,
weights: list = None,
ignore_class: int = None,
loss_type: int = 0,
) -> torch.Tensor:
"""
This function computes the mean class loss between two tensors using the given loss criterion.
Args:
predicted (torch.Tensor): Predicted generally by the network
target (torch.Tensor): Required target label to match the predicted with
num_class (int): Number of classes (including the background class)
loss_criteria (function): Loss function to use
weights (list, optional): Dice weights for each class (excluding the background class), defaults to None
ignore_class (int, optional): Class to ignore, defaults to None
loss_type (int, optional): Type of loss to compute, defaults to 0
loss_type (int, optional): Type of loss to compute, defaults to 0. The options are:
0: no loss, normal dice calculation
1: dice loss, (1-dice)
2: log dice, -log(dice)
Returns:
torch.Tensor: Mean class loss
"""
accumulated_loss = 0
# default to a ridiculous value so that it is ignored by default
ignore_class = -1e10 if ignore_class is None else ignore_class

for class_index in range(num_class):
if class_index != ignore_class:
current_loss = loss_criteria(
predicted[:, class_index, ...], target[:, class_index, ...]
)

# subtract from 1 because this is supposed to be a loss
default_loss = 1 - current_loss
if loss_type == 2 or loss_type == "log":
# negative because we want positive losses, and add epsilon to avoid infinities
current_loss = -torch.log(current_loss + torch.finfo(torch.float32).eps)
else:
current_loss = default_loss

# multiply by appropriate weight if provided
if weights is not None:
current_loss = current_loss * weights[class_index]

accumulated_loss += current_loss

if weights is None:
accumulated_loss /= num_class

return accumulated_loss


acc_dice = 0

for i in range(num_class): # 0 is background
currentDice = dice(predicted[:, i, ...], target[:, i, ...])

if loss_type == 1:
currentDice = 1 - currentDice # subtract from 1 because this is a loss
elif loss_type == 2:
# negative because we want positive losses
currentDice = -torch.log(currentDice + torch.finfo(torch.float32).eps)

if weights is not None:
currentDice = currentDice * weights[i] # multiply by weight

acc_dice += currentDice

if weights is None:
acc_dice /= num_class # we should not be considering 0

return acc_dice

def MCD_loss(predicted, target, params):
"""
These weights should be the penalty weights, not dice weights
"""
return MCD(

def MCD_log_loss(predicted, target, params):
"""
These weights should be the penalty weights, not dice weights
"""
return MCD(


def MCD_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict
) -> torch.Tensor:
"""
This function computes the Dice loss between two tensors. These weights should be the penalty weights, not dice weights.
Args:
predicted (torch.Tensor): The predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with
params (dict): Dictionary of parameters
Returns:
torch.Tensor: The computed Dice loss.
"""
return generic_loss_calculator(
predicted,
target,
len(params["model"]["class_list"]),
dice,
params["weights"],
None,
1,
)


def MCD_log_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict
) -> torch.Tensor:
"""
This function computes the Dice loss between two tensors with log. These weights should be the penalty weights, not dice weights.
Args:
predicted (torch.Tensor): The predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with
params (dict): Dictionary of parameters
Returns:
torch.Tensor: The computed log Dice loss.
"""
return generic_loss_calculator(
predicted,
target,
len(params["model"]["class_list"]),
dice,
params["weights"],
None,
2,
)


def MCC_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict
) -> torch.Tensor:
"""
This function computes the Matthews Correlation Coefficient (MCC) loss between two tensors. These weights should be the penalty weights, not dice weights.
Args:
predicted (torch.Tensor): The predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with
params (dict): Dictionary of parameters
Returns:
torch.Tensor: The computed MCC loss.
"""
return generic_loss_calculator(
predicted,
target,
len(params["model"]["class_list"]),
mcc,
params["weights"],
None,
1,
)


def MCC_log_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict
) -> torch.Tensor:
"""
This function computes the Matthews Correlation Coefficient (MCC) loss between two tensors with log. These weights should be the penalty weights, not dice weights.
Args:
predicted (torch.Tensor): The predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with
params (dict): Dictionary of parameters
Returns:
torch.Tensor: The computed log MCC loss.
"""
return generic_loss_calculator(
predicted,
target,
len(params["model"]["class_list"]),
mcc,
params["weights"],
None,
2,
)


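A hypothetical usage sketch for the new wrappers (not part of the commit); the params values are assumptions chosen so the example runs end to end:

import torch
from GANDLF.losses import MCC_loss, MCC_log_loss

params = {
    "model": {"class_list": [0, 1]},  # background plus one foreground class (assumed)
    "weights": None,                  # no penalty weights -> average over classes
}
predicted = torch.rand(2, 2, 8, 8)   # batch of per-class probabilities
target = (predicted > 0.5).float()   # toy targets correlated with the predictions

print(MCC_loss(predicted, target, params))      # 1 - MCC, averaged over both classes
print(MCC_log_loss(predicted, target, params))  # -log(MCC), averaged over both classes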
def tversky_loss(predicted, target, alpha=0.5, beta=0.5):
def tversky_loss(
predicted: torch.Tensor, target: torch.Tensor, alpha: float = 0.5, beta: float = 0.5
) -> torch.Tensor:
"""
This function calculates the Tversky loss between two tensors.
@@ -127,7 +248,9 @@ def tversky_loss(predicted, target, alpha=0.5, beta=0.5):
return loss


def MCT_loss(predicted, target, params=None):
def MCT_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict = None
) -> torch.Tensor:
"""
This function calculates the Multi-Class Tversky loss between two tensors.
@@ -171,7 +294,9 @@ def KullbackLeiblerDivergence(mu, logvar, params=None):
return loss.mean()


def FocalLoss(predicted, target, params=None):
def FocalLoss(
predicted: torch.Tensor, target: torch.Tensor, params: dict = None
) -> torch.Tensor:
"""
This function calculates the Focal loss between two tensors.
@@ -191,7 +316,7 @@ def FocalLoss(predicted, target, params=None):

def _focal_loss(preds, target, gamma, size_average=True):
"""
Internal helper function to calcualte focal loss for a single class.
Internal helper function to calculate focal loss for a single class.
Args:
preds (torch.Tensor): predicted generally by the network
3 changes: 1 addition & 2 deletions GANDLF/utils/modelio.py
@@ -87,8 +87,6 @@ def optimize_and_save_model(model, params, path, onnx_export=True):
input_names=["input"],
output_names=["output"],
)

ov_output_dir = os.path.dirname(os.path.abspath(path))
except RuntimeWarning:
print("WARNING: Cannot export to ONNX model.")
return
@@ -99,6 +97,7 @@ def optimize_and_save_model(model, params, path, onnx_export=True):
import openvino as ov
from openvino.tools.mo import convert_model
from openvino.runtime import get_version

openvino_present = False
# check for the correct openvino version to prevent inadvertent api breaks
if "2023.0.1" in get_version():
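For context, a minimal self-contained sketch of the PyTorch-to-ONNX export step that this file wraps, assuming a model and a dummy input of the right shape (the names are placeholders, and the OpenVINO conversion guarded by the 2023.0.1 version check above is omitted):

import torch

def export_onnx(model: torch.nn.Module, dummy_input: torch.Tensor, onnx_path: str) -> None:
    model.eval()
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
    )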
2 changes: 1 addition & 1 deletion docs/customize.md
@@ -39,7 +39,7 @@ This file contains mid-level information regarding various parameters that can b
- Defined in the `loss_function` parameter of the model configuration.
- By passing `weighted_loss: True`, the loss function will be weighted by the inverse of the class frequency.
- This parameter controls the function which the model is trained. All options can be found [here](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/losses/__init__.py). Some examples are:
- Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`), focal loss (`focal`), dice and focal (`dc_focal`)
- Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`), focal loss (`focal`), dice and focal (`dc_focal`), matthews (`mcc`)
- Classification/regression: mean squared error (`mse`)
- And many more.

2 changes: 1 addition & 1 deletion samples/config_all_options.yaml
@@ -110,7 +110,7 @@ scheduler:
max_lr: 1,
}
# Set which loss function you want to use - options : 'dc' - for dice only, 'dcce' - for sum of dice and CE and you can guess the next (only lower-case please)
# options: dc (dice only), dc_log (-log of dice), ce (), dcce (sum of dice and ce), mse () ...
# options: dc (dice only), dc_log (-log of dice), ce (), dcce (sum of dice and ce), focal/dc_focal, mcc/mcc_log, mse () ...
# mse is the MSE defined by torch and can define a variable 'reduction'; see https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
# focal is the focal loss and can define 2 variables: gamma and size_average
# use mse_torch for regression/classification problems and dice for segmentation
17 changes: 11 additions & 6 deletions testing/test_full.py
@@ -1213,6 +1213,7 @@ def test_train_metrics_regression_rad_2d(device):

def test_train_losses_segmentation_rad_2d(device):
print("23: Starting 2D Rad segmentation tests for losses")

# helper function to read and parse yaml and return parameters
def get_parameters_after_alteration(loss_type: str) -> dict:
parameters = parseConfig(
@@ -1242,15 +1243,19 @@ def get_parameters_after_alteration(loss_type: str) -> dict:
parameters["model"]["print_summary"] = False
parameters = populate_header_in_parameters(parameters, parameters["headers"])
return parameters, training_data

# loop through selected models and train for single epoch
for loss_type in [
"dc",
"dc_log",
"dcce",
"dcce_logits",
"dc",
"dc_log",
"dcce",
"dcce_logits",
"tversky",
"focal",
"dc_focal"]:
"focal",
"dc_focal",
"mcc",
"mcc_log",
]:
parameters, training_data = get_parameters_after_alteration(loss_type)
sanitize_outputDir()
TrainingManager(
