-
Notifications
You must be signed in to change notification settings - Fork 76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
IN PROGRESS Logging setup #842
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,7 +1,7 @@ | ||||||||||||||||||||
import os | ||||||||||||||||||||
import pathlib | ||||||||||||||||||||
from typing import Optional, Tuple | ||||||||||||||||||||
|
||||||||||||||||||||
import logging | ||||||||||||||||||||
import numpy as np | ||||||||||||||||||||
import pandas as pd | ||||||||||||||||||||
import SimpleITK as sitk | ||||||||||||||||||||
|
@@ -19,6 +19,7 @@ | |||||||||||||||||||
reverse_one_hot, | ||||||||||||||||||||
get_ground_truths_and_predictions_tensor, | ||||||||||||||||||||
print_and_format_metrics, | ||||||||||||||||||||
setup_logger, | ||||||||||||||||||||
) | ||||||||||||||||||||
from GANDLF.metrics import overall_stats | ||||||||||||||||||||
from tqdm import tqdm | ||||||||||||||||||||
|
@@ -46,6 +47,13 @@ def validate_network( | |||||||||||||||||||
Returns: | ||||||||||||||||||||
Tuple[float, dict]: The average validation loss and the average validation metrics. | ||||||||||||||||||||
""" | ||||||||||||||||||||
if "logger_name" in params: | ||||||||||||||||||||
logger = logging.getLogger(params["logger_name"]) | ||||||||||||||||||||
else: | ||||||||||||||||||||
logger, params["logs_dir"], params["logger_name"] = setup_logger( | ||||||||||||||||||||
output_dir=params["output_dir"], verbose=params.get("verbose", False) | ||||||||||||||||||||
) | ||||||||||||||||||||
|
||||||||||||||||||||
print("*" * 20) | ||||||||||||||||||||
print("Starting " + mode + " : ") | ||||||||||||||||||||
print("*" * 20) | ||||||||||||||||||||
|
@@ -70,9 +78,8 @@ def validate_network( | |||||||||||||||||||
is_inference = mode == "inference" | ||||||||||||||||||||
|
||||||||||||||||||||
# automatic mixed precision - https://pytorch.org/docs/stable/amp.html | ||||||||||||||||||||
if params["verbose"]: | ||||||||||||||||||||
if params["model"]["amp"]: | ||||||||||||||||||||
print("Using Automatic mixed precision", flush=True) | ||||||||||||||||||||
if params["model"]["amp"]: | ||||||||||||||||||||
logger.debug("Using Automatic mixed precision") | ||||||||||||||||||||
|
||||||||||||||||||||
if scheduler is None: | ||||||||||||||||||||
current_output_dir = params["output_dir"] # this is in inference mode | ||||||||||||||||||||
|
@@ -115,8 +122,7 @@ def validate_network( | |||||||||||||||||||
for batch_idx, (subject) in enumerate( | ||||||||||||||||||||
tqdm(valid_dataloader, desc="Looping over " + mode + " data") | ||||||||||||||||||||
): | ||||||||||||||||||||
if params["verbose"]: | ||||||||||||||||||||
print("== Current subject:", subject["subject_id"], flush=True) | ||||||||||||||||||||
logger.debug("== Current subject:", subject["subject_id"]) | ||||||||||||||||||||
|
||||||||||||||||||||
# ensure spacing is always present in params and is always subject-specific | ||||||||||||||||||||
params["subject_spacing"] = None | ||||||||||||||||||||
|
@@ -247,16 +253,14 @@ def validate_network( | |||||||||||||||||||
output_prediction = 0 # this is used for regression/classification | ||||||||||||||||||||
current_patch = 0 | ||||||||||||||||||||
for patches_batch in patch_loader: | ||||||||||||||||||||
if params["verbose"]: | ||||||||||||||||||||
print( | ||||||||||||||||||||
"=== Current patch:", | ||||||||||||||||||||
current_patch, | ||||||||||||||||||||
", time : ", | ||||||||||||||||||||
get_date_time(), | ||||||||||||||||||||
", location :", | ||||||||||||||||||||
patches_batch[torchio.LOCATION], | ||||||||||||||||||||
flush=True, | ||||||||||||||||||||
) | ||||||||||||||||||||
logger.debug( | ||||||||||||||||||||
"=== Current patch:", | ||||||||||||||||||||
current_patch, | ||||||||||||||||||||
", time : ", | ||||||||||||||||||||
get_date_time(), | ||||||||||||||||||||
", location :", | ||||||||||||||||||||
patches_batch[torchio.LOCATION], | ||||||||||||||||||||
) | ||||||||||||||||||||
Comment on lines
+256
to
+263
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
or maybe even we don't need time here as it is already included in logging format |
||||||||||||||||||||
current_patch += 1 | ||||||||||||||||||||
image = ( | ||||||||||||||||||||
torch.cat( | ||||||||||||||||||||
|
@@ -279,14 +283,12 @@ def validate_network( | |||||||||||||||||||
|
||||||||||||||||||||
if label is not None: | ||||||||||||||||||||
label = label.to(params["device"]) | ||||||||||||||||||||
if params["verbose"]: | ||||||||||||||||||||
print( | ||||||||||||||||||||
"=== Validation shapes : label:", | ||||||||||||||||||||
label.shape, | ||||||||||||||||||||
", image:", | ||||||||||||||||||||
image.shape, | ||||||||||||||||||||
flush=True, | ||||||||||||||||||||
) | ||||||||||||||||||||
logger.debug( | ||||||||||||||||||||
"=== Validation shapes : label:", | ||||||||||||||||||||
label.shape, | ||||||||||||||||||||
", image:", | ||||||||||||||||||||
image.shape, | ||||||||||||||||||||
) | ||||||||||||||||||||
Comment on lines
+286
to
+291
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||
|
||||||||||||||||||||
if is_inference: | ||||||||||||||||||||
result = step(model, image, None, params, train=False) | ||||||||||||||||||||
|
@@ -427,14 +429,12 @@ def validate_network( | |||||||||||||||||||
output_prediction.to(torch.float32), | ||||||||||||||||||||
params, | ||||||||||||||||||||
) | ||||||||||||||||||||
if params["verbose"]: | ||||||||||||||||||||
print( | ||||||||||||||||||||
"Full image " + mode + ":: Loss: ", | ||||||||||||||||||||
final_loss, | ||||||||||||||||||||
"; Metric: ", | ||||||||||||||||||||
final_metric, | ||||||||||||||||||||
flush=True, | ||||||||||||||||||||
) | ||||||||||||||||||||
logger.debug( | ||||||||||||||||||||
"Full image " + mode + ":: Loss: ", | ||||||||||||||||||||
final_loss, | ||||||||||||||||||||
"; Metric: ", | ||||||||||||||||||||
final_metric, | ||||||||||||||||||||
) | ||||||||||||||||||||
Comment on lines
+432
to
+437
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Because you know logging utils are not like a print, it fails if you pass multiple values instead of printing them together |
||||||||||||||||||||
|
||||||||||||||||||||
# # Non network validing related | ||||||||||||||||||||
# loss.cpu().data.item() | ||||||||||||||||||||
|
@@ -453,28 +453,24 @@ def validate_network( | |||||||||||||||||||
total_epoch_valid_metric[metric] += final_metric[metric] | ||||||||||||||||||||
|
||||||||||||||||||||
if label_ground_truth is not None: | ||||||||||||||||||||
if params["verbose"]: | ||||||||||||||||||||
# For printing information at halftime during an epoch | ||||||||||||||||||||
if ((batch_idx + 1) % (len(valid_dataloader) / 2) == 0) and ( | ||||||||||||||||||||
(batch_idx + 1) < len(valid_dataloader) | ||||||||||||||||||||
): | ||||||||||||||||||||
print( | ||||||||||||||||||||
"\nHalf-Epoch Average " + mode + " loss : ", | ||||||||||||||||||||
total_epoch_valid_loss / (batch_idx + 1), | ||||||||||||||||||||
# For printing information at halftime during an epoch | ||||||||||||||||||||
if ((batch_idx + 1) % (len(valid_dataloader) / 2) == 0) and ( | ||||||||||||||||||||
(batch_idx + 1) < len(valid_dataloader) | ||||||||||||||||||||
): | ||||||||||||||||||||
logger.debug( | ||||||||||||||||||||
"\nHalf-Epoch Average " + mode + " loss : ", | ||||||||||||||||||||
total_epoch_valid_loss / (batch_idx + 1), | ||||||||||||||||||||
) | ||||||||||||||||||||
Comment on lines
+460
to
+463
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||
for metric in params["metrics"]: | ||||||||||||||||||||
if isinstance(total_epoch_valid_metric[metric], np.ndarray): | ||||||||||||||||||||
to_print = ( | ||||||||||||||||||||
total_epoch_valid_metric[metric] / (batch_idx + 1) | ||||||||||||||||||||
).tolist() | ||||||||||||||||||||
else: | ||||||||||||||||||||
to_print = total_epoch_valid_metric[metric] / (batch_idx + 1) | ||||||||||||||||||||
logger.debug( | ||||||||||||||||||||
"Half-Epoch Average " + mode + " " + metric + " : ", to_print | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
(or even rewrite it to |
||||||||||||||||||||
) | ||||||||||||||||||||
for metric in params["metrics"]: | ||||||||||||||||||||
if isinstance(total_epoch_valid_metric[metric], np.ndarray): | ||||||||||||||||||||
to_print = ( | ||||||||||||||||||||
total_epoch_valid_metric[metric] / (batch_idx + 1) | ||||||||||||||||||||
).tolist() | ||||||||||||||||||||
else: | ||||||||||||||||||||
to_print = total_epoch_valid_metric[metric] / ( | ||||||||||||||||||||
batch_idx + 1 | ||||||||||||||||||||
) | ||||||||||||||||||||
print( | ||||||||||||||||||||
"Half-Epoch Average " + mode + " " + metric + " : ", | ||||||||||||||||||||
to_print, | ||||||||||||||||||||
) | ||||||||||||||||||||
|
||||||||||||||||||||
if params["medcam_enabled"] and params["model"]["type"] == "torch": | ||||||||||||||||||||
model.disable_medcam() | ||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.