# Column-name prefix and legend name of each data split, in plotting order.
_SPLITS = (("train", "Training"), ("valid", "Validation"), ("test", "Testing"))
_TRAIN_PREFIX = "train_"


def _plottable_train_metrics(df_training):
    """Return the metric names of the ``train_*`` columns that can be drawn as one line."""
    metrics = []
    for col in df_training.columns:
        # Match on the prefix only; a substring test would also pick up
        # unrelated columns that merely contain "train_" somewhere.
        if not (isinstance(col, str) and col.startswith(_TRAIN_PREFIX)):
            continue
        # Columns whose values contain "_" are skipped: such values are not
        # plain numbers, so a single curve cannot represent them.
        if any("_" in str(val) for val in df_training[col].values):
            continue
        metrics.append(col[len(_TRAIN_PREFIX):])
    return metrics


def _save_metric_plot(metric, frames, epochs, output_plot_dir, short_labels=False):
    """Overlay ``<split>_<metric>`` from every frame that has it and save one PNG."""
    plt.figure(figsize=(12, 6))
    try:
        for (prefix, split_name), frame in zip(_SPLITS, frames):
            column = f"{prefix}_{metric}"
            if frame is None or column not in frame.columns:
                continue
            # Name the column that is actually drawn, so validation/testing
            # curves are not labelled with the training column.
            label = split_name if short_labels else f"{split_name} {column}"
            sns.lineplot(data=frame, x="epoch_no", y=column, label=label)
        # Keep a non-degenerate axis even when only one (or no) epoch was logged.
        plt.xlim(0, max(epochs - 1, 1))
        plt.xlabel("Epoch")
        plt.ylabel(metric.capitalize())
        plt.title(f"{metric.capitalize()} Plot")
        plt.legend()
        plt.savefig(os.path.join(output_plot_dir, f"{metric}_plot.png"), dpi=300)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close()


def plot_all(df_training, df_validation, df_testing, output_plot_dir):
    """
    Plots training, validation, and testing curves for loss and other metrics.

    One figure is written per metric: ``loss_plot.png`` for the loss and
    ``<metric>_plot.png`` for every other ``train_<metric>`` column of
    ``df_training`` whose values do not contain "_". Each figure overlays the
    matching ``valid_<metric>`` and ``test_<metric>`` columns when the
    corresponding DataFrame is given and has that column.

    Args:
        df_training (pd.DataFrame): Training log; must contain "epoch_no" and "train_loss".
        df_validation (pd.DataFrame): Validation log, or None to skip validation curves.
        df_testing (pd.DataFrame): Testing log, or None to skip testing curves.
        output_plot_dir (str): Directory to save the plots; created if missing.

    Returns:
        tuple: The training, validation, and testing DataFrames, unchanged.

    Raises:
        AssertionError: If no plottable ``train_*`` column exists, or if
            "epoch_no" or "train_loss" is missing from ``df_training``.
    """
    metrics = _plottable_train_metrics(df_training)

    # Raised explicitly (instead of via `assert`) so the checks survive
    # `python -O`, while callers still get the same AssertionError as before.
    if not metrics:
        raise AssertionError("None of the specified metrics is in the dataframe.")

    required_cols = ["epoch_no", "train_loss"]
    if not all(col in df_training.columns for col in required_cols):
        raise AssertionError("Not all required columns are in the dataframe.")

    epochs = len(df_training)
    frames = (df_training, df_validation, df_testing)
    Path(output_plot_dir).mkdir(parents=True, exist_ok=True)

    # The loss figure uses the short legend names ("Training", ...).
    _save_metric_plot("loss", frames, epochs, output_plot_dir, short_labels=True)

    # Every remaining metric gets exactly one figure and one output file;
    # "loss" is excluded so the figure above is not drawn and saved twice.
    for metric in metrics:
        if metric != "loss":
            _save_metric_plot(metric, frames, epochs, output_plot_dir)

    print("Plots saved successfully.")
    return df_training, df_validation, df_testing
main(): "--modeldir", metavar="", type=str, - help="Input directory which contains testing and validation models", + help="Input directory which contains testing and validation models log files", ) parser.add_argument( "-o", @@ -35,14 +156,6 @@ def main(): type=str, help="Output directory to save stats and plot", ) - parser.add_argument( - "-c", - "--combinedplots", - metavar="", - default=False, - type=ast.literal_eval, - help="Overlays training and validation plots for both accuracy and loss (classification only).", - ) args = parser.parse_args() @@ -52,202 +165,14 @@ def main(): outputFile = os.path.join(outputDir, "data.csv") # data file name outputPlot = os.path.join(outputDir, "plot.png") # plot file - combinedPlots = args.combinedplots - trainingLogs = os.path.join(inputDir, "logs_training.csv") validationLogs = os.path.join(inputDir, "logs_validation.csv") testingLogs = os.path.join(inputDir, "logs_testing.csv") - if os.path.exists(testingLogs): - testingLogsCSV = pd.read_csv(testingLogs) - - # check for classification task - if len(testingLogsCSV) == 0: - print("Classification task detected, generating accuracy and loss plots.") - - # check whether user wants training + validation overlaid plots - if combinedPlots: - df_training = pd.read_csv(trainingLogs) - df_validation = pd.read_csv(validationLogs) + # Read all the files + df_training = pd.read_csv(trainingLogs) + df_validation = pd.read_csv(validationLogs) + df_testing = pd.read_csv(testingLogs) if os.path.isfile(testingLogs) else None - epochs = len(df_training) - - fig, axes = plt.subplots(nrows=1, ncols=2) # set plot properties - # ensure spacing between plots - plt.subplots_adjust(wspace=0.5, hspace=0.5) - # plot training accuracy data - splot = sns.lineplot( - data=df_training, - x="epoch_no", - y="train_balanced_accuracy", - ax=axes[0], - ) - # plot validation accuracy data - splot = sns.lineplot( - data=df_validation, - x="epoch_no", - y="valid_balanced_accuracy", - ax=axes[0], - ) - # set 
limits for x-axis for proper visualization - splot.set(xlim=(0, epochs - 1)) - # set limits for y-axis for proper visualization - splot.set(ylim=(0, 1)) - # add labels and title to plot - splot.set(xlabel="Epoch", ylabel="Accuracy", title="Accuracy Plot") - # add legend to plot - axes[0].legend(labels=["Training", "Validation"]) - - # plot training loss data - splot = sns.lineplot( - data=df_training, x="epoch_no", y="train_loss", ax=axes[1] - ) - # plot validation loss data - splot = sns.lineplot( - data=df_validation, x="epoch_no", y="valid_loss", ax=axes[1] - ) - # set limits for x-axis for proper visualization - splot.set(xlim=(0, epochs - 1)) - # add labels and title to plot - splot.set(xlabel="Epoch", ylabel="Loss", title="Loss Plot") - # add legend to plot - axes[1].legend(labels=["Training", "Validation"]) - # save plot - plt.savefig(outputPlot, dpi=600) - - print("Plots saved successfully.") - - else: - df_training = pd.read_csv(trainingLogs) - df_validation = pd.read_csv(validationLogs) - - epochs = len(df_training) - - # set plot properties - fig, axes = plt.subplots(nrows=2, ncols=2) - - plt.subplots_adjust(wspace=0.5, hspace=0.5) - # plot the data - splot = sns.lineplot( - data=df_training, - x="epoch_no", - y="train_balanced_accuracy", - ax=axes[0, 0], - ) - splot.set(xlim=(0, epochs - 1)) - splot.set(ylim=(0, 1)) # set limits for y-axis for proper visualization - # set labels - splot.set( - xlabel="Epoch", ylabel="Accuracy", title="Training Accuracy Plot" - ) - - # plot the data - splot = sns.lineplot( - data=df_validation, - x="epoch_no", - y="valid_balanced_accuracy", - ax=axes[0, 1], - ) - splot.set(xlim=(0, epochs - 1)) - splot.set(ylim=(0, 1)) # set limits for y-axis for proper visualization - # set labels - splot.set( - xlabel="Epoch", ylabel="Accuracy", title="Validation Accuracy Plot" - ) - # plot the data - splot = sns.lineplot( - data=df_training, x="epoch_no", y="train_loss", ax=axes[1, 0] - ) - splot.set(xlim=(0, epochs - 1)) - # set 
labels - splot.set(xlabel="Epoch", ylabel="Loss", title="Training Loss Plot") - # plot the data - splot = sns.lineplot( - data=df_validation, x="epoch_no", y="valid_loss", ax=axes[1, 1] - ) - splot.set(xlim=(0, epochs - 1)) - # set labels - splot.set(xlabel="Epoch", ylabel="Loss", title="Validation Loss Plot") - - plt.savefig(outputPlot, dpi=600) - - print("Plots saved successfully.") - - else: - print("Segmentation task detected, generating dice and loss plots.") - - final_stats = "Epoch,Train_Loss,Train_Dice,Val_Loss,Val_Dice,Testing_Loss,Testing_Dice\n" # the columns that need to be present in final output; epoch is always removed - - # loop through output directory - for dirs in os.listdir(inputDir): - currentTestingDir = os.path.join(inputDir, dirs) - if os.path.isdir(currentTestingDir): # go in only if it is a directory - if "testing_" in dirs: # ensure it is part of the testing structure - # loop through all validation directories - for val in os.listdir(currentTestingDir): - currentValidationDir = os.path.join(currentTestingDir, val) - if os.path.isdir(currentValidationDir): - # get all files in each directory - filesInDir = os.listdir(currentValidationDir) - - for i, n in enumerate(filesInDir): - # when the log has been found, collect the final numbers - if "trainingScores_log" in n: - log_file = os.path.join(currentValidationDir, n) - with open(log_file) as f: - for line in f: - pass - final_stats = final_stats + line - - data_string = StringIO(final_stats) - data_full = pd.read_csv(data_string, sep=",") - del data_full["Epoch"] # no need for epoch - data_full.to_csv(outputFile, index=False) # save updated data - - # perform deep copy - data_loss = data_full.copy() - data_dice = data_full.copy() - # set the datasets that need to be plotted - cols = [ - "Train", - "Val", - "Testing", - ] - for i in cols: - del data_dice[i + "_Loss"] # keep only dice - del data_loss[i + "_Dice"] # keep only loss - # rename the columns - data_loss.rename(columns={i + 
"_Loss": i}, inplace=True) - # rename the columns - data_dice.rename(columns={i + "_Dice": i}, inplace=True) - # set plot properties - fig, axes = plt.subplots(nrows=1, ncols=2, constrained_layout=True) - # plot the data - bplot = sns.boxplot( - data=data_dice, width=0.5, palette="colorblind", ax=axes[0] - ) - # set limits for y-axis for proper visualization - bplot.set(ylim=(0, 1)) - # set labels - bplot.set(xlabel="Dataset", ylabel="Dice", title="Dice plot") - # rotate so that everything is visible - bplot.set_xticklabels(bplot.get_xticklabels(), rotation=15, ha="right") - # plot the data - bplot = sns.boxplot( - data=data_loss, width=0.5, palette="colorblind", ax=axes[1] - ) - # set limits for y-axis for proper visualization - bplot.set(ylim=(0, 1)) - # set labels - bplot.set(xlabel="Dataset", ylabel="Loss", title="Loss plot") - # rotate so that everything is visible - bplot.set_xticklabels(bplot.get_xticklabels(), rotation=15, ha="right") - - plt.savefig(outputPlot, dpi=600) - - print("Plots saved successfully.") - - -# main function -if __name__ == "__main__": - main() + # Check for metrics in columns and do tight plots + plot_all(df_training, df_validation, df_testing, outputPlot)