Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added matthews correlation coefficient loss #706

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
0452c26
updated default weight decay
sarthakpati Jul 28, 2023
a908b5b
Merge branch 'master' of https://github.com/sarthakpati/GaNDLF
sarthakpati Jul 28, 2023
5122250
Merge branch 'mlcommons:master' into master
sarthakpati Aug 2, 2023
855d297
added print of weights
sarthakpati Aug 2, 2023
8534c11
Merge branch 'master' of https://github.com/sarthakpati/GaNDLF
sarthakpati Aug 2, 2023
3b7be73
Merge branch 'mlcommons:master' into 705-add-matthews-correlation-coe…
sarthakpati Aug 5, 2023
6d476a2
added comments and expected data types
sarthakpati Aug 5, 2023
44ac9a9
added mcc function
sarthakpati Aug 5, 2023
7bca7d1
using torch epsilon
sarthakpati Aug 5, 2023
cc67266
added 2 types of mcc loss
sarthakpati Aug 5, 2023
2649ce1
updated global dict
sarthakpati Aug 5, 2023
5d4eeb6
renamed
sarthakpati Aug 5, 2023
79c0aba
typo fix
sarthakpati Aug 5, 2023
5efb240
updated variable names and comments
sarthakpati Aug 5, 2023
c8bd0bd
initialized a default loss
sarthakpati Aug 5, 2023
99d5c7f
black .
sarthakpati Aug 5, 2023
1513cca
this is causing problems
sarthakpati Aug 5, 2023
113f8a3
added new losses to test
sarthakpati Aug 5, 2023
58aea3e
no need for this variable
sarthakpati Aug 5, 2023
2852131
added doc
sarthakpati Aug 5, 2023
8c5808b
updated doc
sarthakpati Aug 5, 2023
b26360d
updated all options
sarthakpati Aug 5, 2023
f493d66
updated checkout and renamed job names for clarity
sarthakpati Aug 6, 2023
0c09047
Merge branch 'sarthakpati-patch-2' of https://github.com/mlcommons/Ga…
sarthakpati Aug 6, 2023
d74a318
Merge branch 'sarthakpati-patch-3' of https://github.com/mlcommons/Ga…
sarthakpati Aug 6, 2023
b4316a4
Merge branch 'master' into 705-add-matthews-correlation-coefficient-loss
sarthakpati Aug 7, 2023
89ab7d1
Merge branch 'master' into 705-add-matthews-correlation-coefficient-loss
sarthakpati Aug 7, 2023
b68b0ee
was missing a comma
sarthakpati Aug 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions GANDLF/compute/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
parameters["class_weights"],
) = get_class_imbalance_weights(parameters["training_data"], parameters)

print("Class weights : ", parameters["class_weights"])
print("Penalty weights: ", parameters["weights"])

else:
scheduler = None

Expand Down
9 changes: 9 additions & 0 deletions GANDLF/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
MCT_loss,
KullbackLeiblerDivergence,
FocalLoss,
MCC_loss,
MCC_log_loss,
)
from .regression import CE, CEL, MSE_loss, L1_loss
from .hybrid import DCCE, DCCE_Logits, DC_Focal
Expand All @@ -17,7 +19,14 @@
"dc": MCD_loss,
"dice": MCD_loss,
"dc_log": MCD_log_loss,
"dclog": MCD_log_loss,
"dice_log": MCD_log_loss,
"dicelog": MCD_log_loss,
"mcc": MCC_loss,
"mcc_log": MCC_log_loss,
"mcclog": MCC_log_loss,
"mathews": MCC_loss,
"mathews_log": MCC_log_loss,
"dcce": DCCE,
"dcce_logits": DCCE_Logits,
"ce": CE,
Expand Down
183 changes: 154 additions & 29 deletions GANDLF/losses/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@


# Dice scores and dice losses
def dice(predicted, target) -> torch.Tensor:
def dice(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
This function computes a dice score between two tensors.

Args:
predicted (_type_): Predicted value by the network.
target (_type_): Required target label to match the predicted with
predicted (torch.Tensor): Predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with

Returns:
torch.Tensor: The computed dice score.
Expand All @@ -25,76 +25,197 @@ def dice(predicted, target) -> torch.Tensor:
return dice_score


def MCD(predicted, target, num_class, weights=None, ignore_class=None, loss_type=0):
def mcc(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    This function computes the Matthews Correlation Coefficient (MCC) between two tensors. Adapted from https://github.com/kakumarabhishek/MCC-Loss/blob/main/loss.py.

    Args:
        predictions (torch.Tensor): The predicted value by the network.
        targets (torch.Tensor): Required target label to match the predicted with

    Returns:
        torch.Tensor: The computed MCC score, as a scalar (0-dimensional) tensor.
    """
    # "Soft" confusion-matrix entries, each reduced over every element of the inputs.
    # NOTE(review): presumably both tensors hold values in [0, 1] (probabilities / one-hot masks) - confirm against callers.
    tp = torch.sum(predictions * targets)
    tn = torch.sum((1 - predictions) * (1 - targets))
    fp = torch.sum(predictions * (1 - targets))
    fn = torch.sum((1 - predictions) * targets)

    numerator = tp * tn - fp * fn
    # Plain additions replace the deprecated positional-alpha overload torch.add(a, 1, b), which equals a + 1 * b.
    # Adding epsilon to the denominator to avoid divide-by-zero errors.
    denominator = (
        torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        + torch.finfo(torch.float32).eps
    )

    # numerator and denominator are already scalars, so no further reduction is needed.
    return numerator / denominator


def generic_loss_calculator(
predicted: torch.Tensor,
target: torch.Tensor,
num_class: int,
loss_criteria,
weights: list = None,
ignore_class: int = None,
loss_type: int = 0,
) -> torch.Tensor:
"""
This function computes the mean (or weighted) per-class loss between two tensors, using the supplied loss criteria

Args:
predicted (torch.Tensor): Predicted generally by the network
target (torch.Tensor): Required target label to match the predicted with
num_class (int): Number of classes (including the background class)
loss_criteria (function): Loss function to use
weights (list, optional): Dice weights for each class (excluding the background class), defaults to None
ignore_class (int, optional): Class to ignore, defaults to None
loss_type (int, optional): Type of loss to compute, defaults to 0
loss_type (int, optional): Type of loss to compute, defaults to 0. The options are:
0: treated the same as 1 (any value other than 2 or "log" gives 1 - score)
1: standard loss, (1 - score)
2: log loss, -log(score)

Returns:
torch.Tensor: Mean Class Dice score
"""
accumulated_loss = 0
# default to a ridiculous value so that it is ignored by default
ignore_class = -1e10 if ignore_class is None else ignore_class

for class_index in range(num_class):
if class_index != ignore_class:
current_loss = loss_criteria(
predicted[:, class_index, ...], target[:, class_index, ...]
)

# subtract from 1 because this is supposed to be a loss
default_loss = 1 - current_loss
if loss_type == 2 or loss_type == "log":
# negative because we want positive losses, and add epsilon to avoid infinities
current_loss = -torch.log(current_loss + torch.finfo(torch.float32).eps)
else:
current_loss = default_loss

# multiply by appropriate weight if provided
if weights is not None:
current_loss = current_loss * weights[class_index]

accumulated_loss += current_loss

acc_dice = 0
if weights is None:
accumulated_loss /= num_class

for i in range(num_class): # 0 is background
currentDice = dice(predicted[:, i, ...], target[:, i, ...])
return accumulated_loss

if loss_type == 1:
currentDice = 1 - currentDice # subtract from 1 because this is a loss
elif loss_type == 2:
# negative because we want positive losses
currentDice = -torch.log(currentDice + torch.finfo(torch.float32).eps)

if weights is not None:
currentDice = currentDice * weights[i] # multiply by weight
def MCD_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict
) -> torch.Tensor:
"""
This function computes the Dice loss between two tensors. These weights should be the penalty weights, not dice weights.

acc_dice += currentDice
Args:
predicted (torch.Tensor): The predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with
params (dict): Dictionary of parameters

if weights is None:
acc_dice /= num_class # we should not be considering 0
Returns:
torch.Tensor: The computed Dice loss.
"""
return generic_loss_calculator(
predicted,
target,
len(params["model"]["class_list"]),
dice,
params["weights"],
None,
1,
)

return acc_dice

def MCD_log_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict
) -> torch.Tensor:
"""
This function computes the Dice loss between two tensors with log. These weights should be the penalty weights, not dice weights.

Args:
predicted (torch.Tensor): The predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with
params (dict): Dictionary of parameters

def MCD_loss(predicted, target, params):
Returns:
torch.Tensor: The computed Dice log loss.
"""
These weights should be the penalty weights, not dice weights
return generic_loss_calculator(
predicted,
target,
len(params["model"]["class_list"]),
dice,
params["weights"],
None,
2,
)


def MCC_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict
) -> torch.Tensor:
"""
This function computes the Matthews Correlation Coefficient (MCC) loss between two tensors. These weights should be the penalty weights, not dice weights.

Args:
predicted (torch.Tensor): The predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with
params (dict): Dictionary of parameters

Returns:
torch.Tensor: The computed MCC loss.
"""
return MCD(
return generic_loss_calculator(
predicted,
target,
len(params["model"]["class_list"]),
mcc,
params["weights"],
None,
1,
)


def MCD_log_loss(predicted, target, params):
def MCC_log_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict
) -> torch.Tensor:
"""
These weights should be the penalty weights, not dice weights
This function computes the Matthews Correlation Coefficient (MCC) loss between two tensors with log. These weights should be the penalty weights, not dice weights.

Args:
predicted (torch.Tensor): The predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with
params (dict): Dictionary of parameters

Returns:
torch.Tensor: The computed MCC loss.
"""
return MCD(
return generic_loss_calculator(
predicted,
target,
len(params["model"]["class_list"]),
mcc,
params["weights"],
None,
2,
)


def tversky_loss(predicted, target, alpha=0.5, beta=0.5):
def tversky_loss(
predicted: torch.Tensor, target: torch.Tensor, alpha: float = 0.5, beta: float = 0.5
) -> torch.Tensor:
"""
This function calculates the Tversky loss between two tensors.

Expand Down Expand Up @@ -127,7 +248,9 @@ def tversky_loss(predicted, target, alpha=0.5, beta=0.5):
return loss


def MCT_loss(predicted, target, params=None):
def MCT_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict = None
) -> torch.Tensor:
"""
This function calculates the Multi-Class Tversky loss between two tensors.

Expand Down Expand Up @@ -171,7 +294,9 @@ def KullbackLeiblerDivergence(mu, logvar, params=None):
return loss.mean()


def FocalLoss(predicted, target, params=None):
def FocalLoss(
predicted: torch.Tensor, target: torch.Tensor, params: dict = None
) -> torch.Tensor:
"""
This function calculates the Focal loss between two tensors.

Expand All @@ -191,7 +316,7 @@ def FocalLoss(predicted, target, params=None):

def _focal_loss(preds, target, gamma, size_average=True):
"""
Internal helper function to calcualte focal loss for a single class.
Internal helper function to calculate focal loss for a single class.

Args:
preds (torch.Tensor): predicted generally by the network
Expand Down
3 changes: 1 addition & 2 deletions GANDLF/utils/modelio.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ def optimize_and_save_model(model, params, path, onnx_export=True):
input_names=["input"],
output_names=["output"],
)

ov_output_dir = os.path.dirname(os.path.abspath(path))
except RuntimeWarning:
print("WARNING: Cannot export to ONNX model.")
return
Expand All @@ -99,6 +97,7 @@ def optimize_and_save_model(model, params, path, onnx_export=True):
import openvino as ov
from openvino.tools.mo import convert_model
from openvino.runtime import get_version

openvino_present = False
# check for the correct openvino version to prevent inadvertent api breaks
if "2023.0.1" in get_version():
Expand Down
2 changes: 1 addition & 1 deletion docs/customize.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ This file contains mid-level information regarding various parameters that can b
- Defined in the `loss_function` parameter of the model configuration.
- By passing `weighted_loss: True`, the loss function will be weighted by the inverse of the class frequency.
- This parameter controls the function which the model is trained. All options can be found [here](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/losses/__init__.py). Some examples are:
- Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`), focal loss (`focal`), dice and focal (`dc_focal`)
- Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`), focal loss (`focal`), dice and focal (`dc_focal`), matthews (`mcc`)
- Classification/regression: mean squared error (`mse`)
- And many more.

Expand Down
2 changes: 1 addition & 1 deletion samples/config_all_options.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ scheduler:
max_lr: 1,
}
# Set which loss function you want to use - options : 'dc' - for dice only, 'dcce' - for sum of dice and CE and you can guess the next (only lower-case please)
# options: dc (dice only), dc_log (-log of dice), ce (), dcce (sum of dice and ce), mse () ...
# options: dc (dice only), dc_log (-log of dice), ce (), dcce (sum of dice and ce), focal/dc_focal, mcc/mcc_log, mse () ...
# mse is the MSE defined by torch and can define a variable 'reduction'; see https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
# focal is the focal loss and can define 2 variables: gamma and size_average
# use mse_torch for regression/classification problems and dice for segmentation
Expand Down
17 changes: 11 additions & 6 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,7 @@ def test_train_metrics_regression_rad_2d(device):

def test_train_losses_segmentation_rad_2d(device):
print("23: Starting 2D Rad segmentation tests for losses")

# helper function to read and parse yaml and return parameters
def get_parameters_after_alteration(loss_type: str) -> dict:
parameters = parseConfig(
Expand Down Expand Up @@ -1242,15 +1243,19 @@ def get_parameters_after_alteration(loss_type: str) -> dict:
parameters["model"]["print_summary"] = False
parameters = populate_header_in_parameters(parameters, parameters["headers"])
return parameters, training_data

# loop through selected models and train for single epoch
for loss_type in [
"dc",
"dc_log",
"dcce",
"dcce_logits",
"dc",
"dc_log",
"dcce",
"dcce_logits",
"tversky",
"focal",
"dc_focal"]:
"focal",
"dc_focal",
"mcc",
"mcc_log",
]:
parameters, training_data = get_parameters_after_alteration(loss_type)
sanitize_outputDir()
TrainingManager(
Expand Down
Loading