## 📊 W&B Data Extraction, Processing, and Excel Export 📂

#### This notebook contains code to:
- Retrieve data from **Weights & Biases** using the API
- Perform minimal processing on the data to obtain best metric values and metric values at the epoch where a certain metric is maximized (e.g. Mean Dice)
- Export the processed data to **Excel** for further analysis


---

#### 🔬 This cell is where you specify the experiments that you would like to analyze

In [39]:
import wandb

api = wandb.Api()

entity = 'brats_dann'   # Keep the same
project = 'Debugging loss codes'  # Your project name
experiment_name = 'isbi24'  # text segment to identify experiment

runs = api.runs(path=f'{entity}/{project}')

# Map each fold label (e.g. 'fold0') to the id of the run that belongs to it.
# NOTE: if several matching runs share the same fold label, the last one listed wins.
run_ids_dict = {}
for run in runs:
    # Only keep runs whose name contains the experiment identifier
    if experiment_name not in run.name:
        continue
    print("run name = ",run.name," id: ", run.id)

    # The fold label is taken from the first '-'-separated segment of the run name
    # that contains 'fold' (e.g. 'UNet-fold1-test3_boxloss_isbi24' -> 'fold1').
    fold_parts = [part for part in run.name.split('-') if 'fold' in part]
    if not fold_parts:
        # Without a fold segment the run cannot be assigned to a fold; skip it
        # instead of failing with an IndexError.
        print(f"Skipping run '{run.name}': no 'fold' segment found in its name")
        continue

    fold_no = 'fold' + fold_parts[0].replace('fold', '')
    run_ids_dict[fold_no] = run.id
run_ids_dict

run name =  UNet-fold1-test3_boxloss_isbi24  id:  kqgjghul
run name =  UNet-fold2-test3_boxloss_isbi24  id:  rmugedvf
run name =  UNet-fold0-test3_boxloss_isbi24  id:  reupo79e
run name =  UNet-fold4-test3_boxloss_isbi24  id:  f5hu42jk
run name =  UNet-fold3-test3_boxloss_isbi24  id:  jra6we3x


{'fold1': 'kqgjghul',
 'fold2': 'rmugedvf',
 'fold0': 'reupo79e',
 'fold4': 'f5hu42jk',
 'fold3': 'jra6we3x'}

#### 📈 This cell contains a function that allows you to fetch and plot metrics of interest.

In [44]:
import wandb
import matplotlib.pyplot as plt
import math
import os

def fetch_wandb_metrics(entity, project, run_id, horizontal_axis, vertical_axes, show_plots = True):
    """
    Fetch data from Weights and Biases (W&B) and optionally plot multiple vertical axes against a common horizontal axis.
    Plots are arranged in rows of 3 columns.

    Parameters:
    entity (str): W&B entity (username or team).
    project (str): W&B project name.
    run_id (str): W&B run ID.
    horizontal_axis (str): Column to use for the horizontal axis.
    vertical_axes (list): List of columns to use for the vertical axes.
    show_plots (bool): Whether to show plots

    Returns:
    results_dict (dict): Dictionary where each key is a vertical axis and the value is the DataFrame
        (columns: horizontal axis, vertical axis; rows with missing values removed) of the queried data
        for that axis. Vertical axes that are not present in the run history are reported and left out.
        Returns None when the horizontal axis or all of the vertical axes are missing from the run history.
    """
    # Initialize the W&B API
    api = wandb.Api()

    run = api.run(f"{entity}/{project}/{run_id}")

    history = run.history(samples=100000)  # Adjust the sample size if needed

    vertical_axes = list(vertical_axes)
    columns_to_query = [horizontal_axis] + vertical_axes
    available_columns = [col for col in columns_to_query if col in history.columns]

    if not available_columns:
        print(f"None of the specified columns were found in the run history: {columns_to_query}")
        return None

    # Everything below groups by the horizontal axis, so it has to exist
    if horizontal_axis not in available_columns:
        print(f"Horizontal axis '{horizontal_axis}' was not found in the run history.")
        return None

    # Only work with the vertical axes that were actually logged for this run
    found_axes = [col for col in vertical_axes if col in available_columns]
    missing_axes = [col for col in vertical_axes if col not in available_columns]
    if missing_axes:
        print(f"Skipping columns not found in the run history: {missing_axes}")
    if not found_axes:
        print(f"None of the requested vertical axes were found in the run history: {vertical_axes}")
        return None

    # Query the data for the requested columns
    queried_data = history[[horizontal_axis] + found_axes]

    # Consolidate data by grouping by the horizontal axis and keeping, per column,
    # the first non-NaN value logged for each horizontal-axis value
    consolidated_data = queried_data.groupby(horizontal_axis).agg(
        lambda x: x.dropna().iloc[0] if not x.dropna().empty else None
    ).reset_index()
    # Drop rows in which none of the vertical axes has a value
    queried_data = consolidated_data.dropna(how='all', subset=found_axes)

    # Dictionary to hold data for each vertical axis
    results_dict = {
        vertical_axis: queried_data[[horizontal_axis, vertical_axis]].dropna()
        for vertical_axis in found_axes
    }

    # Set plot style (kept unconditional: this changes the global matplotlib style,
    # exactly as before, whether or not a figure is drawn)
    plt.style.use('fivethirtyeight')

    # No figure is needed when plots are not requested
    if not show_plots:
        return results_dict

    # Determine the number of rows and columns for subplots (max 3 columns per row)
    num_plots = len(found_axes)
    num_cols = min(3, num_plots)  # Max 3 columns
    num_rows = math.ceil(num_plots / 3)  # Calculate the required number of rows

    # squeeze=False always yields a 2D array of axes, so a single flatten covers
    # the 1-plot, 1-row and multi-row layouts alike
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(6 * num_cols, 6 * num_rows), squeeze=False)
    axes = axes.flatten()

    # Plot each vertical axis in its corresponding subplot
    for ax, vertical_axis in zip(axes, found_axes):
        axis_data = results_dict[vertical_axis]
        ax.plot(axis_data[horizontal_axis], axis_data[vertical_axis], marker='o', linestyle='-', linewidth=2)
        ax.set_ylabel(vertical_axis, fontsize=14)
        ax.set_xlabel(horizontal_axis, fontsize=14)
        ax.grid(True, which='both', linestyle='--', linewidth=0.7)
        ax.set_title(vertical_axis, fontsize=14, fontweight='bold')

    # Hide any unused subplots (if the number of plots is less than the grid size)
    for unused_ax in axes[num_plots:]:
        fig.delaxes(unused_ax)

    # Title the figure with the run name and leave room for it above the subplots
    fig.suptitle(run.name)
    fig.tight_layout(rect=[0, 0, 1, 0.97])

    # Return the dictionary containing the queried data
    return results_dict

#### Run this cell to define metrics_dicts. You can show plots by setting show_plots = True

In [41]:
# Example usage: collect the validation metrics of each of the 5 folds.

# Settings shared by every fold
entity = 'brats_dann'
project = 'Debugging loss codes'
# project ='BraTS Goat 5-fold Test'
horizontal_axis = 'epoch'
vertical_axes = [
    'mean_Dice_val', 'mean_HD_val',
    'Dice1_val', 'Dice2_val', 'Dice3_val',
    'HD1_val', 'HD2_val', 'HD3_val',
]

# fold name (e.g. 'fold0') -> result of fetch_wandb_metrics for that fold's run
metrics_dicts = {}
for fold_no in range(5):
    fold_name = f'fold{fold_no}'
    run_id = run_ids_dict[fold_name]
    metrics_dicts[fold_name] = fetch_wandb_metrics(
        entity,
        project,
        run_id,
        horizontal_axis,
        vertical_axes,
        show_plots = False,
    )

#### This function returns the values of all metrics at a specified epoch, optionally saving them to an Excel file.

In [49]:
import pandas as pd
import os

def get_metrics_at_epoch(metrics_dicts, target_epoch, excel_save_dir=None, save_to_excel=True):
    """
    Collect the value of every metric at `target_epoch` for each fold, append an 'Average' row,
    and optionally save the table to Excel.

    Parameters:
    metrics_dicts (dict): Maps a fold name to a dict of {metric name: DataFrame}. Each DataFrame must
        have an 'epoch' column and a column named after the metric. Folds whose value is None are skipped.
    target_epoch: Epoch at which the metric values are read.
    excel_save_dir (str): Directory for the Excel file. Defaults to the current working directory.
    save_to_excel (bool): Whether to write 'Metrics_At_Epoch_{target_epoch}.xlsx' into excel_save_dir.

    Returns:
    df_to_save (DataFrame): One row per fold with columns 'Fold', 'Epoch' and one column per metric
        (None where the fold has no value at target_epoch), followed by a row with Fold == 'Average'
        holding the column means. If no fold contributes a row, an empty DataFrame with the columns
        'Fold' and 'Epoch' is returned.
    """
    data_for_excel = []

    # Loop through each fold key in the metrics dictionary
    for key, metrics_dfs in metrics_dicts.items():

        # A fold without fetched metrics (None) cannot contribute a row
        if metrics_dfs is None:
            print(f"Skipping {key}: no metrics available.")
            continue

        # Initialize a dictionary to store metric values for this fold at the target epoch
        metric_data = {'Fold': key, 'Epoch': target_epoch}

        # Loop through each metric DataFrame and find the values at the target epoch
        for metric_name, df_metric in metrics_dfs.items():
            df_metric_filtered = df_metric[df_metric['epoch'] == target_epoch]
            if not df_metric_filtered.empty:
                metric_row = df_metric_filtered.iloc[0]
                metric_data[metric_name] = metric_row[metric_name]
            else:
                metric_data[metric_name] = None  # If target_epoch is not found in this metric

        # Append the collected data for this fold to the list
        data_for_excel.append(metric_data)

    if data_for_excel:
        # Convert the list of dictionaries to a DataFrame
        df_to_save = pd.DataFrame(data_for_excel)

        # Calculate the mean of each metric column, excluding non-numeric columns like 'Fold' and 'Epoch'.
        # Columns holding only None are object-typed; coerce them to NaN so the mean is well defined.
        metric_values = df_to_save.drop(columns=['Fold', 'Epoch']).apply(pd.to_numeric, errors='coerce')
        average_metrics = metric_values.mean().to_dict()
        average_metrics['Fold'] = 'Average'
        average_metrics['Epoch'] = target_epoch  # Keep the target epoch for reference

        # Convert the averages dictionary to a DataFrame and concatenate it with the original DataFrame
        average_df = pd.DataFrame([average_metrics])
        df_to_save = pd.concat([df_to_save, average_df], ignore_index=True)
    else:
        # Nothing to average: return an empty table instead of failing on missing columns
        df_to_save = pd.DataFrame(columns=['Fold', 'Epoch'])

    # Define default file path if none is provided
    if not excel_save_dir:
        excel_save_dir = os.getcwd()

    excel_file_path = os.path.join(excel_save_dir, f'Metrics_At_Epoch_{target_epoch}.xlsx')

    if save_to_excel:
        df_to_save.to_excel(excel_file_path, index=False)
        print(f"Data saved to Excel at: {excel_file_path}")
    return df_to_save

#### This function finds the epoch where the metric **'max_metric_for_epoch'** is maximized and returns the values of all metrics at that epoch, optionally saving them to an Excel file.

In [50]:
import pandas as pd

def get_metrics_at_max_metric_epoch_for_all_folds(metrics_dicts, epoch_cutoff=100, max_metric_for_epoch='mean_Dice_val', excel_save_dir=None, save_to_excel = True):
    """
    For every fold, find the epoch (<= epoch_cutoff) at which `max_metric_for_epoch` is largest and
    collect the values of all metrics at that epoch. An 'Average' row is appended and the table is
    optionally saved to Excel.

    Parameters:
    metrics_dicts (dict): Maps a fold name to a dict of {metric name: DataFrame}; each DataFrame has an
        'epoch' column and a column named after the metric.
    epoch_cutoff: Only epochs <= epoch_cutoff are considered when searching for the maximum.
    max_metric_for_epoch (str): Name of the metric whose maximum selects the epoch.
    excel_save_dir (str): Directory for the Excel file. Defaults to the current working directory.
    save_to_excel (bool): Whether to write
        'Metrics_At_Max_{max_metric_for_epoch}_Epoch_Upto_{epoch_cutoff}.xlsx' into excel_save_dir.

    Returns:
    df_to_save (DataFrame): One row per usable fold (the columns produced by get_metrics_at_epoch plus
        'Max_Epoch'), followed by a row with Fold == 'Average' holding the metric means; the epoch
        columns of that row are 'N/A'. Folds lacking the metric, or without any epoch within the
        cutoff, are reported and skipped. If no fold is usable, an empty DataFrame with the columns
        'Fold' and 'Max_Epoch' is returned.
    """
    data_for_excel = []

    # Loop through each fold key in the metrics dictionary
    for key, metrics_dfs in metrics_dicts.items():

        # The selecting metric must be present for this fold
        if not metrics_dfs or max_metric_for_epoch not in metrics_dfs:
            print(f"Skipping {key}: metric '{max_metric_for_epoch}' is not available.")
            continue

        # Filter to only epochs within the cutoff
        df_selector = metrics_dfs[max_metric_for_epoch]
        df_within_cutoff = df_selector[df_selector['epoch'] <= epoch_cutoff]
        if df_within_cutoff.empty:
            # idxmax would fail on an empty selection
            print(f"Skipping {key}: no '{max_metric_for_epoch}' values at epoch <= {epoch_cutoff}.")
            continue

        # Epoch at which the selecting metric reaches its maximum
        best_index = df_within_cutoff[max_metric_for_epoch].idxmax()
        best_epoch = df_within_cutoff.loc[best_index, 'epoch']

        # Use get_metrics_at_epoch to retrieve all metrics at best_epoch for this fold
        metrics_at_epoch_df = get_metrics_at_epoch({key: metrics_dfs}, target_epoch=best_epoch, save_to_excel = False)
        metrics_data = metrics_at_epoch_df.iloc[0].to_dict()  # Convert row to dictionary

        # Set Fold and Max_Epoch for clarity
        metrics_data['Fold'] = key
        metrics_data['Max_Epoch'] = best_epoch
        data_for_excel.append(metrics_data)

    if data_for_excel:
        # Convert the list of dictionaries to a DataFrame
        df_to_save = pd.DataFrame(data_for_excel)

        # Calculate the mean of each metric column. 'Fold' is a label and the epoch columns
        # ('Epoch' duplicates 'Max_Epoch') are excluded since averaging epochs doesn't make sense.
        metric_values = df_to_save.drop(columns=['Fold', 'Max_Epoch', 'Epoch'], errors='ignore')
        average_metrics = metric_values.apply(pd.to_numeric, errors='coerce').mean().to_dict()
        average_metrics['Fold'] = 'Average'
        average_metrics['Max_Epoch'] = 'N/A'  # Since averaging epochs doesn't make sense
        if 'Epoch' in df_to_save.columns:
            average_metrics['Epoch'] = 'N/A'

        # Convert the averages dictionary to a DataFrame and concatenate it with the original DataFrame
        average_df = pd.DataFrame([average_metrics])
        df_to_save = pd.concat([df_to_save, average_df], ignore_index=True)
    else:
        # No usable fold: return an empty table instead of failing on missing columns
        df_to_save = pd.DataFrame(columns=['Fold', 'Max_Epoch'])

    # Define default file path if none is provided
    if not excel_save_dir:
        excel_save_dir = os.getcwd()

    excel_file_path = os.path.join(excel_save_dir, f'Metrics_At_Max_{max_metric_for_epoch}_Epoch_Upto_{epoch_cutoff}.xlsx')

    if save_to_excel:
        df_to_save.to_excel(excel_file_path, index=False)
        # Only report a save that actually happened
        print(f"Data saved to Excel at: {excel_file_path}")
    return df_to_save

#### This function returns, for every fold, the maximum or minimum value of each listed metric (together with the epoch at which it occurs), optionally saving the results to an Excel file.

In [51]:
import pandas as pd

def get_max_min_metrics(metrics_dicts, max_names, max_metrics, min_names, min_metrics, epoch_cutoff=75, excel_save_dir=None, save_to_excel = True):
    """
    For every fold, find the maximum of each metric in `max_metrics` and the minimum of each metric in
    `min_metrics` among the epochs strictly below `epoch_cutoff`, and optionally save the table to Excel.

    Parameters:
    metrics_dicts (dict): Maps a fold name to a dict of {metric name: DataFrame}; each DataFrame has an
        'epoch' column and a column named after the metric.
    max_names (list): Display names for the maximized metrics; max_names[i] labels max_metrics[i].
    max_metrics (list): Names of the metrics to maximize.
    min_names (list): Display names for the minimized metrics; min_names[i] labels min_metrics[i].
    min_metrics (list): Names of the metrics to minimize.
    epoch_cutoff: Only epochs < epoch_cutoff are considered (the cutoff itself is excluded).
    excel_save_dir (str): Directory for the Excel file. Defaults to the current working directory.
    save_to_excel (bool): Whether to write 'max_min_metrics.xlsx' into excel_save_dir.

    Returns:
    results_df (DataFrame): Indexed by ('Fold', 'Metric Name', 'Type') with the columns 'Value' and
        'Epoch'. Metrics without any epoch below the cutoff are reported and left out.
    """
    data = []

    # Maximized and minimized metrics are handled identically apart from the selector used
    selections = (
        ('Max', max_names, max_metrics, 'idxmax'),
        ('Min', min_names, min_metrics, 'idxmin'),
    )

    # Process each fold in the metrics dictionary
    for key, fold_metrics in metrics_dicts.items():
        for kind, names, metrics, selector in selections:
            for idx, metric in enumerate(metrics):
                df = fold_metrics[metric]
                df_limited = df[df['epoch'] < epoch_cutoff]
                if df_limited.empty:
                    # idxmax/idxmin would fail on an empty selection
                    print(f"Skipping {metric} for {key}: no values at epoch < {epoch_cutoff}.")
                    continue

                # Row at which the metric reaches its extreme value
                best_index = getattr(df_limited[metric], selector)()
                best_row = df_limited.loc[best_index]
                data.append({
                    'Fold': key,
                    'Metric Name': names[idx],
                    'Value': best_row[metric],
                    'Epoch': int(best_row['epoch']),
                    'Type': kind
                })

    # Create a DataFrame from the collected data; naming the columns keeps the
    # index below valid even when nothing was collected
    results_df = pd.DataFrame(data, columns=['Fold', 'Metric Name', 'Value', 'Epoch', 'Type'])

    # Set multi-level index
    results_df = results_df.set_index(['Fold', 'Metric Name', 'Type'])

    # Define default file path if none is provided. The string 'None' was the former
    # default value and is treated as "not provided" as well.
    if not excel_save_dir or excel_save_dir == 'None':
        excel_save_dir = os.getcwd()

    excel_file_path = os.path.join(excel_save_dir, 'max_min_metrics.xlsx')

    if save_to_excel:
        # The index holds Fold / Metric Name / Type, so it must be written to the file
        results_df.to_excel(excel_file_path)
        # Only report a save that actually happened
        print(f"Data saved to Excel at: {excel_file_path}")

    return results_df

In [53]:
# Directory into which the Excel exports below are written.
# NOTE(review): absolute, machine-specific path; adjust it before running elsewhere.
excel_save_dir = '/Users/juampablo/Desktop/Kurtlab_A24/isbi_results'

# Epoch to fetch metrics from
target_epoch = 99
# Export the metrics of every fold at target_epoch. This call is the last expression
# of the cell, so the DataFrame it returns is also displayed as the cell output.
get_metrics_at_epoch(metrics_dicts=metrics_dicts, excel_save_dir = excel_save_dir, target_epoch=target_epoch)

# The remaining exports are disabled; uncomment a section to run it.

# # Name of metric to maximize
# max_metric_for_epoch='mean_Dice_val'
# get_metrics_at_max_metric_epoch_for_all_folds(metrics_dicts, epoch_cutoff=100, max_metric_for_epoch=max_metric_for_epoch, excel_save_dir=excel_save_dir)


# # Names of metrics to maximize and minimize
# max_names = ['max_dice', 'max_dice1', 'max_dice2', 'max_dice3']
# max_metrics = ['mean_Dice_val', 'Dice1_val', 'Dice2_val', 'Dice3_val']
# min_names = ['min_hd', 'min_hd1', 'min_hd2', 'min_hd3']
# min_metrics = ['mean_HD_val', 'HD1_val', 'HD2_val', 'HD3_val']  

# get_max_min_metrics(metrics_dicts, max_names, max_metrics, min_names, min_metrics, epoch_cutoff=100, excel_save_dir=excel_save_dir)

Data saved to Excel at: /Users/juampablo/Desktop/Kurtlab_A24/isbi_results/Metrics_At_Epoch_99.xlsx


Unnamed: 0,Fold,Epoch,mean_Dice_val,mean_HD_val,Dice1_val,Dice2_val,Dice3_val,HD1_val,HD2_val,HD3_val
0,fold0,99,0.853002,8.977049,0.867547,0.855051,0.836408,9.963797,8.738432,8.228903
1,fold1,99,0.854783,8.751613,0.865037,0.861902,0.837409,10.220544,8.260969,7.773327
2,fold2,99,0.855373,9.96665,0.867001,0.858318,0.840801,11.766107,9.484021,8.649774
3,fold3,99,0.839136,10.862476,0.861933,0.84471,0.810763,12.06313,10.689711,9.834559
4,fold4,99,0.839929,8.712781,0.855483,0.841862,0.822441,10.358562,8.746573,7.033202
5,Average,99,0.848444,9.454114,0.8634,0.852369,0.829564,10.874428,9.183941,8.303953
