Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated the metrics output #687

Merged
22 changes: 9 additions & 13 deletions GANDLF/compute/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
resample_image,
reverse_one_hot,
get_ground_truths_and_predictions_tensor,
print_and_format_metrics,
)
from GANDLF.metrics import overall_stats
from tqdm import tqdm
Expand Down Expand Up @@ -498,19 +499,14 @@ def validate_network(
average_epoch_valid_metric = overall_stats(
predictions_array, ground_truth_array, params
)
for metric in params["metrics"]:
if isinstance(total_epoch_valid_metric[metric], np.ndarray):
to_print = (
total_epoch_valid_metric[metric] / len(valid_dataloader)
).tolist()
else:
to_print = total_epoch_valid_metric[metric] / len(valid_dataloader)
average_epoch_valid_metric[metric] = to_print
for metric in average_epoch_valid_metric.keys():
print(
" Epoch Final " + mode + " " + metric + " : ",
average_epoch_valid_metric[metric],
)
average_epoch_valid_metric = print_and_format_metrics(
average_epoch_valid_metric,
total_epoch_valid_metric,
params["metrics"],
mode,
len(valid_dataloader),
)

else:
average_epoch_valid_loss, average_epoch_valid_metric = 0, {}

Expand Down
42 changes: 18 additions & 24 deletions GANDLF/compute/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
print_model_summary,
get_ground_truths_and_predictions_tensor,
get_model_dict,
print_and_format_metrics,
)
from GANDLF.metrics import overall_stats
from GANDLF.logger import Logger
Expand Down Expand Up @@ -204,19 +205,13 @@ def train_network(model, train_dataloader, optimizer, params):
average_epoch_train_metric = overall_stats(
predictions_array, ground_truth_array, params
)
for metric in params["metrics"]:
if isinstance(total_epoch_train_metric[metric], np.ndarray):
to_print = (
total_epoch_train_metric[metric] / len(train_dataloader)
).tolist()
else:
to_print = total_epoch_train_metric[metric] / len(train_dataloader)
average_epoch_train_metric[metric] = to_print
for metric in average_epoch_train_metric.keys():
print(
" Epoch Final train " + metric + " : ",
average_epoch_train_metric[metric],
)
average_epoch_train_metric = print_and_format_metrics(
average_epoch_train_metric,
total_epoch_train_metric,
params["metrics"],
"train",
len(train_dataloader),
)

return average_epoch_train_loss, average_epoch_train_metric

Expand Down Expand Up @@ -348,15 +343,12 @@ def training_loop(
overall_metrics = overall_stats(torch.Tensor([1]), torch.Tensor([1]), params)
elif params["problem_type"] == "classification":
# this is just used to generate the headers for the overall stats
org_num_classes = params["model"]["num_classes"]
params["model"]["num_classes"] = 3
temp_tensor = torch.randint(0, params["model"]["num_classes"], (5,))
overall_metrics = overall_stats(
torch.Tensor([0, 0, 2, 2, 1, 2]).to(dtype=torch.int32),
torch.Tensor([0, 0, 2, 2, 1, 2]).to(dtype=torch.int32),
temp_tensor.to(dtype=torch.int32),
temp_tensor.to(dtype=torch.int32),
params,
)
# original number of classes are restored
params["model"]["num_classes"] = org_num_classes

metrics_log = params["metrics"].copy()
if calculate_overall_metrics:
Expand All @@ -373,13 +365,15 @@ def training_loop(
logger_csv_filename=os.path.join(output_dir, "logs_validation.csv"),
metrics=metrics_log,
)
test_logger = Logger(
logger_csv_filename=os.path.join(output_dir, "logs_testing.csv"),
metrics=metrics_log,
)
if testingDataDefined:
test_logger = Logger(
logger_csv_filename=os.path.join(output_dir, "logs_testing.csv"),
metrics=metrics_log,
)
train_logger.write_header(mode="train")
valid_logger.write_header(mode="valid")
test_logger.write_header(mode="test")
if testingDataDefined:
test_logger.write_header(mode="test")

if "medcam" in params:
model = medcam.inject(
Expand Down
1 change: 1 addition & 0 deletions GANDLF/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
get_array_from_image_or_tensor,
suppress_stdout_stderr,
set_determinism,
print_and_format_metrics,
)

from .modelio import (
Expand Down
65 changes: 65 additions & 0 deletions GANDLF/utils/generic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os, datetime, sys
from copy import deepcopy
import random
import numpy as np
import torch
Expand Down Expand Up @@ -201,3 +202,67 @@ def set_determinism(seed=42):
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True


def print_and_format_metrics(
cohort_level_metrics,
sample_level_metrics,
metrics_dict_from_parameters,
mode,
length_of_dataloader,
):
"""
This function prints and formats the metrics.

Args:
cohort_level_metrics (dict): The cohort level metrics calculated from the GANDLF.metrics.overall_stats function.
sample_level_metrics (dict): The sample level metrics calculated from separate samples from the dataloader(s).
metrics_dict_from_parameters (dict): The metrics dictionary to populate.
mode (str): The mode of the metrics (train, val, test).
length_of_dataloader (int): The length of the dataloader.

Returns:
dict: The metrics dictionary populated with the metrics.
"""

def __update_metric_from_list_to_single_string(input_metrics_dict) -> dict:
"""
Helper function updates the metrics dictionary to have a single string for each metric.

Args:
input_metrics_dict (dict): The input metrics dictionary.
Returns:
dict: The updated metrics dictionary.
"""
print(input_metrics_dict)
output_metrics_dict = deepcopy(input_metrics_dict)
for metric in input_metrics_dict.keys():
if isinstance(input_metrics_dict[metric], list):
output_metrics_dict[metric] = ("_").join(
str(input_metrics_dict[metric])
.replace("[", "")
.replace("]", "")
.replace(" ", "")
.split(",")
)

print(output_metrics_dict)
return output_metrics_dict

output_metrics_dict = deepcopy(cohort_level_metrics)
for metric in metrics_dict_from_parameters:
if isinstance(sample_level_metrics[metric], np.ndarray):
to_print = (sample_level_metrics[metric] / length_of_dataloader).tolist()
else:
to_print = sample_level_metrics[metric] / length_of_dataloader
output_metrics_dict[metric] = to_print
for metric in output_metrics_dict.keys():
print(
" Epoch Final " + mode + " " + metric + " : ",
output_metrics_dict[metric],
)
output_metrics_dict = __update_metric_from_list_to_single_string(
output_metrics_dict
)

return output_metrics_dict
Loading