Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IN PROGRESS Logging setup #842

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion GANDLF/cli/main_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
populate_header_in_parameters,
parseTrainingCSV,
parseTestingCSV,
setup_logger,
)


Expand Down Expand Up @@ -59,6 +60,11 @@ def main_run(
parameters["output_dir"] = model_dir
Path(parameters["output_dir"]).mkdir(parents=True, exist_ok=True)

# setup logger
logger, parameters["logs_dir"], parameters["logger_name"] = setup_logger(
output_dir=output_dir, verbose=parameters.get("verbose", False)
)

if "-1" in device:
device = "cpu"

Expand Down Expand Up @@ -87,7 +93,7 @@ def main_run(
), "The training and testing CSVs do not have the same header information."

parameters = populate_header_in_parameters(parameters, headers_train)
# if we are here, it is assumed that the user wants to do training
logger.debug("if we are here, it is assumed that the user wants to do training")
if train_mode:
TrainingManager_split(
dataframe_train=data_train,
Expand Down
8 changes: 3 additions & 5 deletions GANDLF/cli/preprocess_and_save.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os, sys, pickle
import os, pickle, warnings
from typing import Optional
from pathlib import Path
import SimpleITK as sitk
Expand Down Expand Up @@ -88,10 +88,8 @@ def preprocess_and_save(
(parameters["patch_sampler"] == "label")
or (isinstance(parameters["patch_sampler"], dict))
) and parameters["q_samples_per_volume"] > 1:
print(
"[WARNING] Label sampling has been enabled but q_samples_per_volume > 1; this has been known to cause issues, so q_samples_per_volume will be hard-coded to 1 during preprocessing. Please contact GaNDLF developers for more information",
file=sys.stderr,
flush=True,
warnings.warn(
"Label sampling has been enabled but q_samples_per_volume > 1; this has been known to cause issues, so q_samples_per_volume will be hard-coded to 1 during preprocessing. Please contact GaNDLF developers for more information"
)

for _, (subject) in enumerate(
Expand Down
102 changes: 49 additions & 53 deletions GANDLF/compute/forward_pass.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import pathlib
from typing import Optional, Tuple

import logging
import numpy as np
import pandas as pd
import SimpleITK as sitk
Expand All @@ -19,6 +19,7 @@
reverse_one_hot,
get_ground_truths_and_predictions_tensor,
print_and_format_metrics,
setup_logger,
)
from GANDLF.metrics import overall_stats
from tqdm import tqdm
Expand Down Expand Up @@ -46,6 +47,13 @@ def validate_network(
Returns:
Tuple[float, dict]: The average validation loss and the average validation metrics.
"""
if "logger_name" in params:
logger = logging.getLogger(params["logger_name"])
else:
logger, params["logs_dir"], params["logger_name"] = setup_logger(
output_dir=params["output_dir"], verbose=params.get("verbose", False)
)

print("*" * 20)
print("Starting " + mode + " : ")
print("*" * 20)
Expand All @@ -70,9 +78,8 @@ def validate_network(
is_inference = mode == "inference"

# automatic mixed precision - https://pytorch.org/docs/stable/amp.html
if params["verbose"]:
if params["model"]["amp"]:
print("Using Automatic mixed precision", flush=True)
if params["model"]["amp"]:
logger.debug("Using Automatic mixed precision")

if scheduler is None:
current_output_dir = params["output_dir"] # this is in inference mode
Expand Down Expand Up @@ -115,8 +122,7 @@ def validate_network(
for batch_idx, (subject) in enumerate(
tqdm(valid_dataloader, desc="Looping over " + mode + " data")
):
if params["verbose"]:
print("== Current subject:", subject["subject_id"], flush=True)
logger.debug("== Current subject:", subject["subject_id"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.debug("== Current subject:", subject["subject_id"])
logger.debug(f"== Current subject: {subject['subject_id']}")


# ensure spacing is always present in params and is always subject-specific
params["subject_spacing"] = None
Expand Down Expand Up @@ -247,16 +253,14 @@ def validate_network(
output_prediction = 0 # this is used for regression/classification
current_patch = 0
for patches_batch in patch_loader:
if params["verbose"]:
print(
"=== Current patch:",
current_patch,
", time : ",
get_date_time(),
", location :",
patches_batch[torchio.LOCATION],
flush=True,
)
logger.debug(
"=== Current patch:",
current_patch,
", time : ",
get_date_time(),
", location :",
patches_batch[torchio.LOCATION],
)
Comment on lines +256 to +263
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.debug(
"=== Current patch:",
current_patch,
", time : ",
get_date_time(),
", location :",
patches_batch[torchio.LOCATION],
)
logger.debug(f"=== Current patch: {current_patch}, time : {get_date_time()}, location : {patches_batch[torchio.LOCATION]}")

or maybe even we don't need time here as it is already included in logging format

current_patch += 1
image = (
torch.cat(
Expand All @@ -279,14 +283,12 @@ def validate_network(

if label is not None:
label = label.to(params["device"])
if params["verbose"]:
print(
"=== Validation shapes : label:",
label.shape,
", image:",
image.shape,
flush=True,
)
logger.debug(
"=== Validation shapes : label:",
label.shape,
", image:",
image.shape,
)
Comment on lines +286 to +291
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.debug(
"=== Validation shapes : label:",
label.shape,
", image:",
image.shape,
)
logger.debug(f"=== Validation shapes : label: {label.shape}, image: {image.shape}")


if is_inference:
result = step(model, image, None, params, train=False)
Expand Down Expand Up @@ -427,14 +429,12 @@ def validate_network(
output_prediction.to(torch.float32),
params,
)
if params["verbose"]:
print(
"Full image " + mode + ":: Loss: ",
final_loss,
"; Metric: ",
final_metric,
flush=True,
)
logger.debug(
"Full image " + mode + ":: Loss: ",
final_loss,
"; Metric: ",
final_metric,
)
Comment on lines +432 to +437
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.debug(
"Full image " + mode + ":: Loss: ",
final_loss,
"; Metric: ",
final_metric,
)
logger.debug(f"Full image {mode}:: Loss: {final_loss}; Metric: {final_metric}")

Because you know logging utils are not like a print, it fails if you pass multiple values instead of printing them together


# # Non network validing related
# loss.cpu().data.item()
Expand All @@ -453,28 +453,24 @@ def validate_network(
total_epoch_valid_metric[metric] += final_metric[metric]

if label_ground_truth is not None:
if params["verbose"]:
# For printing information at halftime during an epoch
if ((batch_idx + 1) % (len(valid_dataloader) / 2) == 0) and (
(batch_idx + 1) < len(valid_dataloader)
):
print(
"\nHalf-Epoch Average " + mode + " loss : ",
total_epoch_valid_loss / (batch_idx + 1),
# For printing information at halftime during an epoch
if ((batch_idx + 1) % (len(valid_dataloader) / 2) == 0) and (
(batch_idx + 1) < len(valid_dataloader)
):
logger.debug(
"\nHalf-Epoch Average " + mode + " loss : ",
total_epoch_valid_loss / (batch_idx + 1),
)
Comment on lines +460 to +463
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.debug(
"\nHalf-Epoch Average " + mode + " loss : ",
total_epoch_valid_loss / (batch_idx + 1),
)
avg_loss = total_epoch_valid_loss / (batch_idx + 1)
logger.debug(f"\nHalf-Epoch Average {mode} loss : {avg_loss}")

for metric in params["metrics"]:
if isinstance(total_epoch_valid_metric[metric], np.ndarray):
to_print = (
total_epoch_valid_metric[metric] / (batch_idx + 1)
).tolist()
else:
to_print = total_epoch_valid_metric[metric] / (batch_idx + 1)
logger.debug(
"Half-Epoch Average " + mode + " " + metric + " : ", to_print
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Half-Epoch Average " + mode + " " + metric + " : ", to_print
"Half-Epoch Average " + mode + " " + metric + " : " + to_print

(or even rewrite it to f-strings)

)
for metric in params["metrics"]:
if isinstance(total_epoch_valid_metric[metric], np.ndarray):
to_print = (
total_epoch_valid_metric[metric] / (batch_idx + 1)
).tolist()
else:
to_print = total_epoch_valid_metric[metric] / (
batch_idx + 1
)
print(
"Half-Epoch Average " + mode + " " + metric + " : ",
to_print,
)

if params["medcam_enabled"] and params["model"]["type"] == "torch":
model.disable_medcam()
Expand Down
17 changes: 13 additions & 4 deletions GANDLF/compute/generic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Tuple
from pandas.util import hash_pandas_object
import torch
import torch, logging
from torch.utils.data import DataLoader

from GANDLF.models import get_model
Expand All @@ -12,6 +12,7 @@
parseTrainingCSV,
send_model_to_device,
get_class_imbalance_weights,
setup_logger,
)


Expand Down Expand Up @@ -40,6 +41,14 @@ def create_pytorch_objects(
Returns:
Tuple[ torch.nn.Module, torch.optim.Optimizer, DataLoader, DataLoader, torch.optim.lr_scheduler.LRScheduler, dict, ]: The model, optimizer, train loader, validation loader, scheduler, and parameters.
"""
if "logger_name" in parameters:
logger = logging.getLogger(parameters["logger_name"])
else:
logger, parameters["logs_dir"], parameters["logger_name"] = setup_logger(
output_dir=parameters["output_dir"],
verbose=parameters.get("verbose", False),
)

# initialize train and val loaders
train_loader, val_loader = None, None
headers_to_populate_train, headers_to_populate_val = None, None
Expand All @@ -60,9 +69,9 @@ def create_pytorch_objects(
parameters["class_weights"],
) = get_class_imbalance_weights(parameters["training_data"], parameters)

print("Penalty weights : ", parameters["penalty_weights"])
print("Sampling weights: ", parameters["sampling_weights"])
print("Class weights : ", parameters["class_weights"])
logger.debug(f"Penalty weights : {parameters['penalty_weights']}")
logger.debug(f"Sampling weights: {parameters['sampling_weights']}")
logger.debug(f"Class weights : {parameters['class_weights']}")

# get the train loader
train_loader = get_train_loader(parameters)
Expand Down
26 changes: 18 additions & 8 deletions GANDLF/compute/inference_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
import cv2
import logging
import numpy as np
from torch.utils.data import DataLoader
from skimage.io import imsave
Expand All @@ -23,6 +24,7 @@
load_ov_model,
print_model_summary,
applyCustomColorMap,
setup_logger,
)

from GANDLF.data.inference_dataloader_histopath import InferTumorSegDataset
Expand All @@ -46,12 +48,20 @@ def inference_loop(
modelDir (str): The path to the directory containing the model to be used for inference.
outputDir (str): The path to the directory where the output of the inference session will be stored.
"""
if "logger_name" in parameters:
logger = logging.getLogger(parameters["logger_name"])
else:
logger, parameters["logs_dir"], parameters["logger_name"] = setup_logger(
output_dir=parameters["output_dir"],
verbose=parameters.get("verbose", False),
)

# Defining our model here according to parameters mentioned in the configuration file
print("Current model type : ", parameters["model"]["type"])
print("Number of dims : ", parameters["model"]["dimension"])
logger.debug("Current model type : ", parameters["model"]["type"])
logger.debug("Number of dims : ", parameters["model"]["dimension"])
if "num_channels" in parameters["model"]:
print("Number of channels : ", parameters["model"]["num_channels"])
print("Number of classes : ", len(parameters["model"]["class_list"]))
logger.debug("Number of channels : ", parameters["model"]["num_channels"])
logger.debug("Number of classes : ", len(parameters["model"]["class_list"]))
parameters["testing_data"] = inferenceDataFromPickle

# ensure outputs are saved properly
Expand Down Expand Up @@ -115,7 +125,7 @@ def inference_loop(
parameters["model"]["IO"] = [input_blob, output_blob]

if not (os.environ.get("HOSTNAME") is None):
print("\nHostname :" + str(os.environ.get("HOSTNAME")), flush=True)
logger.debug("\nHostname :" + str(os.environ.get("HOSTNAME")))

# radiology inference
if parameters["modality"] == "rad":
Expand All @@ -131,7 +141,7 @@ def inference_loop(
# Setting up the inference loader
inference_loader = get_testing_loader(parameters)

print("Data Samples: ", len(inference_loader.dataset), flush=True)
logger.debug("Data Samples: ", len(inference_loader.dataset))

average_epoch_valid_loss, average_epoch_valid_metric = validate_network(
model, inference_loader, None, parameters, mode="inference"
Expand Down Expand Up @@ -287,8 +297,8 @@ def inference_loop(
# Check if out_probs_map is greater than 1, print a warning
if np.max(probs_map) > 1:
# Print a warning
print(
"Warning: Probability map is greater than 1, report the images to GaNDLF developers"
logger.warning(
"Probability map is greater than 1, report the images to GaNDLF developers"
)

if count_map is not None:
Expand Down
7 changes: 3 additions & 4 deletions GANDLF/compute/loss_and_metric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys
import warnings
from typing import Dict, Tuple
from GANDLF.losses import global_losses_dict
from GANDLF.metrics import global_metrics_dict
Expand Down Expand Up @@ -63,9 +63,8 @@ def get_loss_and_metrics(
if loss_str_lower in global_losses_dict:
loss_function = global_losses_dict[loss_str_lower]
else:
sys.exit(
"WARNING: Could not find the requested loss function '"
+ params["loss_function"]
warnings.warn(
"Could not find the requested loss function '" + params["loss_function"]
)

loss = 0
Expand Down