Merge branch 'master' into sarthakpati-patch-1

Geeks-Sid committed Mar 22, 2022
2 parents e433051 + 2ae23ff commit e3578b9
Showing 7 changed files with 170 additions and 120 deletions.
198 changes: 104 additions & 94 deletions GANDLF/compute/forward_pass.py
@@ -8,6 +8,7 @@
 
 from GANDLF.utils import (
     get_date_time,
+    get_unique_timestamp,
     get_filename_extension_sanitized,
     reverse_one_hot,
 )
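
Note: the newly imported get_unique_timestamp lives in GANDLF/utils/generic.py, whose body is not part of this diff. A minimal sketch of what such a helper might look like (the exact format string is an assumption, not the repository's code):

    from datetime import datetime

    def get_unique_timestamp():
        # Microsecond resolution makes filename collisions between runs unlikely.
        return datetime.now().strftime("%Y%m%d%H%M%S%f")
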
@@ -86,18 +87,15 @@ def validate_network(
         model.enable_medcam()
         params["medcam_enabled"] = True
 
-    if params["save_output"]:
-        if "value_keys" in params:
+    if params["save_output"] or is_inference:
+        if params["problem_type"] != "segmentation":
-            outputToWrite = "Epoch,SubjectID,PredictedValue\n"
-            file_to_write = os.path.join(current_output_dir, "output_predictions.csv")
-            if os.path.exists(file_to_write):
-                # append to previously generated file
-                file = open(file_to_write, "a")
-                outputToWrite = ""
-            else:
-                # if file was absent, write header information
-                file = open(file_to_write, "w")
+            # used to write output
+            outputToWrite = "Epoch,SubjectID,PredictedValue\n"
+            file_to_write = os.path.join(
+                current_output_dir,
+                "output_predictions_" + get_unique_timestamp() + ".csv",
+            )
 
     for batch_idx, (subject) in enumerate(
         tqdm(valid_dataloader, desc="Looping over " + mode + " data")
@@ -143,7 +141,7 @@
                 )
 
             # regression/classification problem AND label is present
-            if ("value_keys" in params) and label_present:
+            if (params["problem_type"] != "segmentation") and label_present:
                 sampler = torchio.data.LabelSampler(params["patch_size"])
                 tio_subject = torchio.Subject(subject_dict)
                 generator = sampler(tio_subject, num_patches=params["q_samples_per_volume"])
@@ -175,14 +173,12 @@
 
                 pred_output = pred_output.cpu() / params["q_samples_per_volume"]
                 pred_output /= params["scaling_factor"]
-                # all_predics.append(pred_output.double())
-                # all_targets.append(valuesToPredict.double())
 
                 if is_inference and is_classification:
                     logits_list.append(pred_output)
                     subject_id_list.append(subject.get("subject_id")[0])
 
-                if params["save_output"]:
+                if params["save_output"] or is_inference:
                     outputToWrite += (
                         str(epoch)
                         + ","
@@ -216,7 +212,6 @@
 
                 output_prediction = 0  # this is used for regression/classification
                 current_patch = 0
-                is_segmentation = True
                 for patches_batch in patch_loader:
                     if params["verbose"]:
                         print(
@@ -240,23 +235,27 @@
                             .float()
                             .to(params["device"])
                         )
-                    if "value_keys" in params:
-                        is_segmentation = False
+
+                    # calculate metrics if ground truth is present
+                    label = None
+                    if params["problem_type"] != "segmentation":
                         label = label_ground_truth
                     else:
                         label = patches_batch["label"][torchio.DATA]
-                    label = label.to(params["device"])
-                    if params["verbose"]:
-                        print(
-                            "=== Validation shapes : label:",
-                            label.shape,
-                            ", image:",
-                            image.shape,
-                            flush=True,
-                        )
+
+                    if label is not None:
+                        label = label.to(params["device"])
+                        if params["verbose"]:
+                            print(
+                                "=== Validation shapes : label:",
+                                label.shape,
+                                ", image:",
+                                image.shape,
+                                flush=True,
+                            )
 
                     if is_inference:
-                        result = step(model, image, label, params, train=False)
+                        result = step(model, image, None, params, train=False)
                     else:
                         result = step(model, image, label, params, train=True)
 
@@ -269,7 +268,7 @@
                     else:
                         _, _, output = result
 
-                    if is_segmentation:
+                    if params["problem_type"] == "segmentation":
                         aggregator.add_batch(
                             output.detach().cpu(), patches_batch[torchio.LOCATION]
                         )
@@ -281,19 +280,16 @@
                         output_prediction += output
 
                 # save outputs
-                if is_segmentation:
+                if params["problem_type"] == "segmentation":
                     output_prediction = aggregator.get_output_tensor()
                     output_prediction = output_prediction.unsqueeze(0)
-                    label_ground_truth = label_ground_truth.unsqueeze(0).to(torch.float32)
                     if params["save_output"]:
                         img_for_metadata = torchio.Image(
-                            type=subject["label"]["type"],
-                            tensor=subject["label"]["data"].squeeze(0),
-                            affine=subject["label"]["affine"].squeeze(0),
+                            type=subject["1"]["type"],
+                            tensor=subject["1"]["data"].squeeze(0),
+                            affine=subject["1"]["affine"].squeeze(0),
                         ).as_sitk()
-                        ext = get_filename_extension_sanitized(
-                            subject["path_to_metadata"][0]
-                        )
+                        ext = get_filename_extension_sanitized(subject["1"]["path"][0])
                         pred_mask = output_prediction.numpy()
                         # '0' because validation/testing dataloader always has batch size of '1'
                         pred_mask = reverse_one_hot(
@@ -312,9 +308,7 @@
                         result_image.CopyInformation(img_for_metadata)
 
                         # cast as the same data type
-                        result_image = sitk.Cast(
-                            result_image, img_for_metadata.GetPixelID()
-                        )
+                        result_image = sitk.Cast(result_image, sitk.sitkUInt16)
                         # this handles cases that need resampling/resizing
                         if "resample" in params["data_preprocessing"]:
                             resampler = torchio.transforms.Resample(
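
For context, the save path above follows the usual SimpleITK round trip: build an image from the predicted array, copy the reference geometry, then cast. A standalone sketch with toy data (the reference image stands in for img_for_metadata):

    import numpy as np
    import SimpleITK as sitk

    pred_mask = np.zeros((16, 16, 16), dtype=np.float32)  # toy prediction volume
    reference = sitk.Image(16, 16, 16, sitk.sitkFloat32)  # stand-in for img_for_metadata

    result_image = sitk.GetImageFromArray(pred_mask)
    result_image.CopyInformation(reference)  # copies origin, spacing, and direction
    # this commit switches from matching the reference's pixel type to a fixed uint16 cast
    result_image = sitk.Cast(result_image, sitk.sitkUInt16)
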
@@ -355,68 +349,80 @@
                     subject_id_list.append(subject.get("subject_id")[0])
 
                 # we cast to float32 because float16 was causing nan
-                final_loss, final_metric = get_loss_and_metrics(
-                    image,
-                    label_ground_truth,
-                    output_prediction.to(torch.float32),
-                    params,
-                )
-                if params["verbose"]:
-                    print(
-                        "Full image " + mode + ":: Loss: ",
-                        final_loss,
-                        "; Metric: ",
-                        final_metric,
-                        flush=True,
-                    )
+                if label_ground_truth is not None:
+                    # this is for RGB label
+                    if label_ground_truth.shape[0] == 3:
+                        label_ground_truth = label_ground_truth[0, ...].unsqueeze(0)
+                    # we always want the ground truth to be in the same format as the prediction
+                    label_ground_truth = label_ground_truth.unsqueeze(0)
+                    if label_ground_truth.shape[-1] == 1:
+                        label_ground_truth = label_ground_truth.squeeze(-1)
+                    final_loss, final_metric = get_loss_and_metrics(
+                        image,
+                        label_ground_truth,
+                        output_prediction.to(torch.float32),
+                        params,
+                    )
+                    if params["verbose"]:
+                        print(
+                            "Full image " + mode + ":: Loss: ",
+                            final_loss,
+                            "; Metric: ",
+                            final_metric,
+                            flush=True,
+                        )
 
-                # # Non network validing related
-                # loss.cpu().data.item()
-                total_epoch_valid_loss += final_loss.cpu().item()
-                for metric in final_metric.keys():
-                    if isinstance(total_epoch_valid_metric[metric], list):
-                        total_epoch_valid_metric[metric].append(final_metric[metric])
-                    else:
-                        total_epoch_valid_metric[metric] += final_metric[metric]
+                    # # Non network validing related
+                    # loss.cpu().data.item()
+                    total_epoch_valid_loss += final_loss.cpu().item()
+                    for metric in final_metric.keys():
+                        if isinstance(total_epoch_valid_metric[metric], list):
+                            total_epoch_valid_metric[metric].append(final_metric[metric])
+                        else:
+                            total_epoch_valid_metric[metric] += final_metric[metric]
 
-        # For printing information at halftime during an epoch
-        if ((batch_idx + 1) % (len(valid_dataloader) / 2) == 0) and (
-            (batch_idx + 1) < len(valid_dataloader)
-        ):
-            print(
-                "\nHalf-Epoch Average " + mode + " loss : ",
-                total_epoch_valid_loss / (batch_idx + 1),
-            )
-            for metric in params["metrics"]:
-                if isinstance(total_epoch_valid_metric[metric], list):
-                    to_print = (
-                        np.array(total_epoch_valid_metric[metric]) / (batch_idx + 1)
-                    ).tolist()
-                else:
-                    to_print = total_epoch_valid_metric[metric] / (batch_idx + 1)
-                print(
-                    "Half-Epoch Average " + mode + " " + metric + " : ",
-                    to_print,
-                )
+        if label_ground_truth is not None:
+            # For printing information at halftime during an epoch
+            if ((batch_idx + 1) % (len(valid_dataloader) / 2) == 0) and (
+                (batch_idx + 1) < len(valid_dataloader)
+            ):
+                print(
+                    "\nHalf-Epoch Average " + mode + " loss : ",
+                    total_epoch_valid_loss / (batch_idx + 1),
+                )
+                for metric in params["metrics"]:
+                    if isinstance(total_epoch_valid_metric[metric], list):
+                        to_print = (
+                            np.array(total_epoch_valid_metric[metric]) / (batch_idx + 1)
+                        ).tolist()
+                    else:
+                        to_print = total_epoch_valid_metric[metric] / (batch_idx + 1)
+                    print(
+                        "Half-Epoch Average " + mode + " " + metric + " : ",
+                        to_print,
+                    )
 
     if params["medcam_enabled"] and params["model"]["type"] == "torch":
         model.disable_medcam()
         params["medcam_enabled"] = False
 
-    average_epoch_valid_loss = total_epoch_valid_loss / len(valid_dataloader)
-    print(" Epoch Final " + mode + " loss : ", average_epoch_valid_loss)
-    for metric in params["metrics"]:
-        if isinstance(total_epoch_valid_metric[metric], list):
-            to_print = (
-                np.array(total_epoch_valid_metric[metric]) / len(valid_dataloader)
-            ).tolist()
-        else:
-            to_print = total_epoch_valid_metric[metric] / len(valid_dataloader)
-        average_epoch_valid_metric[metric] = to_print
-        print(
-            " Epoch Final " + mode + " " + metric + " : ",
-            average_epoch_valid_metric[metric],
-        )
+    if label_ground_truth is not None:
+        average_epoch_valid_loss = total_epoch_valid_loss / len(valid_dataloader)
+        print(" Epoch Final " + mode + " loss : ", average_epoch_valid_loss)
+        for metric in params["metrics"]:
+            if isinstance(total_epoch_valid_metric[metric], list):
+                to_print = (
+                    np.array(total_epoch_valid_metric[metric]) / len(valid_dataloader)
+                ).tolist()
+            else:
+                to_print = total_epoch_valid_metric[metric] / len(valid_dataloader)
+            average_epoch_valid_metric[metric] = to_print
+            print(
+                " Epoch Final " + mode + " " + metric + " : ",
+                average_epoch_valid_metric[metric],
+            )
+    else:
+        average_epoch_valid_loss, average_epoch_valid_metric = 0, {}
 
     if scheduler is not None:
         if params["scheduler"]["type"] in [
@@ -441,11 +447,15 @@ def validate_network(
             logits_df.SubjectID = subject_id_list
             logits_df[class_list] = logit_tensor
 
-        logits_df.to_csv(
-            os.path.join(current_fold_dir, "logits.csv"), index=False, sep=","
-        )
+        logits_file = os.path.join(current_fold_dir, "logits.csv")
+        if os.path.isfile(logits_file):
+            logits_file = os.path.join(
+                current_fold_dir, "logits_" + get_unique_timestamp() + ".csv"
+            )
+        logits_df.to_csv(logits_file, index=False, sep=",")
 
     if "value_keys" in params:
         file = open(file_to_write, "w")
         file.write(outputToWrite)
        file.close()
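
This file and GANDLF/inference_manager.py below now write CSVs with the same collision-avoidance pattern: try the canonical name, and fall back to a timestamped one if a file already exists. A condensed sketch of the pattern (save_df_without_overwrite is a hypothetical name, not a GaNDLF function):

    import os

    import pandas as pd
    from GANDLF.utils import get_unique_timestamp

    def save_df_without_overwrite(df: pd.DataFrame, directory: str, basename: str) -> str:
        # prefer the canonical filename; never clobber an existing result
        path = os.path.join(directory, basename + ".csv")
        if os.path.isfile(path):
            path = os.path.join(directory, basename + "_" + get_unique_timestamp() + ".csv")
        df.to_csv(path, index=False)
        return path
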
35 changes: 20 additions & 15 deletions GANDLF/compute/step.py
@@ -40,25 +40,27 @@ def step(model, image, label, params, train=True):
         )
 
     # for the weird cases where mask is read as an RGB image, ensure only the first channel is used
-    if params["problem_type"] == "segmentation":
-        if label.shape[1] == 3:
-            label = label[:, 0, ...].unsqueeze(1)
-            # this warning should only come up once
-            if params["print_rgb_label_warning"]:
-                print(
-                    "WARNING: The label image is an RGB image, only the first channel will be used.",
-                    flush=True,
-                )
-                params["print_rgb_label_warning"] = False
+    if label is not None:
+        if params["problem_type"] == "segmentation":
+            if label.shape[1] == 3:
+                label = label[:, 0, ...].unsqueeze(1)
+                # this warning should only come up once
+                if params["print_rgb_label_warning"]:
+                    print(
+                        "WARNING: The label image is an RGB image, only the first channel will be used.",
+                        flush=True,
+                    )
+                    params["print_rgb_label_warning"] = False
 
-    if params["model"]["dimension"] == 2:
-        label = torch.squeeze(label, -1)
+        if params["model"]["dimension"] == 2:
+            label = torch.squeeze(label, -1)
 
     if params["model"]["dimension"] == 2:
         image = torch.squeeze(image, -1)
-        if "value_keys" in params:
-            if len(label.shape) > 1:
-                label = torch.squeeze(label, -1)
+        if label is not None:
+            if len(label.shape) > 1:
+                label = torch.squeeze(label, -1)
 
     if train == False and params["model"]["type"].lower() == "openvino":
         output = torch.from_numpy(
@@ -78,7 +80,10 @@
             output, attention_map = output
 
     # one-hot encoding of 'label' will probably be needed for segmentation
-    loss, metric_output = get_loss_and_metrics(image, label, output, params)
+    if label is not None:
+        loss, metric_output = get_loss_and_metrics(image, label, output, params)
+    else:
+        loss, metric_output = None, None
 
     if len(output) > 1:
         output = output[0]
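
With these changes, step tolerates a missing ground truth: inference callers pass label=None and must expect loss and metric_output to come back as None. A minimal illustration of the calling convention (model, image, label, and params are assumed to be set up as in forward_pass.py, with medcam disabled so step returns a 3-tuple):

    # training/validation, ground truth available
    loss, metrics, output = step(model, image, label, params, train=True)

    # inference, no ground truth -- loss and metrics come back as None
    loss, metrics, output = step(model, image, None, params, train=False)
    if loss is not None:
        total_epoch_valid_loss += loss.cpu().item()
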
21 changes: 11 additions & 10 deletions GANDLF/inference_manager.py
@@ -1,9 +1,10 @@
-from GANDLF.compute import inference_loop
 import os
 import numpy as np
-import pandas as pd
 import torch
 import torch.nn.functional as F
+import pandas as pd
+
+from GANDLF.compute import inference_loop
+from GANDLF.utils import get_unique_timestamp
 
 
 def InferenceManager(dataframe, outputDir, parameters, device):
@@ -74,10 +75,10 @@ def InferenceManager(dataframe, outputDir, parameters, device):
         averaged_probs_df.PredictedClass = [
             class_list[a] for a in averaged_probs.argmax(1)
         ]
-        averaged_probs_df.to_csv(
-            os.path.join(
-                outputDir, "final_predictions_with_averaged_probabilities.csv"
-            ),
-            index=False,
-            sep=",",
-        )
+        filepath_to_save = os.path.join(outputDir, "final_preds_and_avg_probs.csv")
+        if os.path.isfile(filepath_to_save):
+            filepath_to_save = os.path.join(
+                outputDir,
+                "final_preds_and_avg_probs" + get_unique_timestamp() + ".csv",
+            )
+        averaged_probs_df.to_csv(filepath_to_save, index=False)
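
The PredictedClass column maps each row's argmax over the averaged per-class probabilities back to a class label. A toy example of that mapping (the class names are hypothetical):

    import numpy as np

    class_list = ["background", "tumor"]      # hypothetical two-class setup
    averaged_probs = np.array([[0.8, 0.2],    # toy averaged probabilities, one row per subject
                               [0.3, 0.7]])
    predicted_class = [class_list[a] for a in averaged_probs.argmax(1)]
    # predicted_class == ["background", "tumor"]
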
1 change: 1 addition & 0 deletions GANDLF/utils/__init__.py
@@ -31,6 +31,7 @@
 from .generic import (
     fix_paths,
     get_date_time,
+    get_unique_timestamp,
     get_filename_extension_sanitized,
     version_check,
 )
[diffs for the remaining 3 changed files were not loaded]