# Evaluation

In this document, we analyse our models' performance on the training data.

We will test the effect of three different parameters:

- Number of orientations (1, 2 or 3)
- Prediction accumulation function (mean or mean > 0.5)
- Peak detection function (center of mass or max filter)
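
To make the second parameter concrete, here is a minimal sketch of the two accumulation modes. It assumes that each orientation produces a heatmap of the same shape with values in [0, 1], and it reads "mean > 0.5" as a binarised mean; both are assumptions. The helper name `accumulate_predictions` and its `threshold` argument are hypothetical and are not taken from main.py.

```python
import numpy as np


def accumulate_predictions(heatmaps, threshold=None):
    """Combine per-orientation heatmaps into one volume (hypothetical helper)."""
    # stack to shape (n_orientations, x, y, z) and average over orientations
    mean = np.stack(heatmaps, axis=0).mean(axis=0)
    if threshold is None:
        # "mean" mode: keep the averaged heatmap as-is
        return mean
    # "mean > threshold" mode: binarise the averaged heatmap
    return (mean > threshold).astype(np.float32)


# three random heatmaps standing in for three orientations
rng = np.random.default_rng(0)
heatmaps = [rng.random((4, 4, 4)) for _ in range(3)]

mean_volume = accumulate_predictions(heatmaps)
binary_volume = accumulate_predictions(heatmaps, threshold=0.5)
print(mean_volume.shape, int(binary_volume.sum()))
```

The binarised variant discards how confident the averaged prediction was, which may be part of why the two modes behave differently during peak detection.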

Let's first look at the help text for the Python inference script:

In [None]:
# !python main.py infer -h

Now let's run all model combinations using a handy bash script at scripts/infer_with_all_models.sh:

In [None]:
# !bash scripts/infer_with_all_models.sh

## Compare their performance against the labelled data

In [None]:
output_folder_prefix = './'

# config_path = './configs/fiddlercrab_corneas.yaml'
# y_path = './dataset/fiddlercrab_corneas/whole/test_labels_10/dampieri_male_16-corneas.csv'
# y_hat_path = 'output/dampieri_male_16-image.logs_fiddlercrab_corneas_lightning_logs_version_26_checkpoints_last_x_3_y_3_z_3_average_threshold_0.5_prediction_peak_min_val_0_5_method_center_of_mass.resampled_space_peaks.csv'
# mct_path = './dataset/fiddlercrab_corneas/whole/test_images_10/dampieri_male_16-image.nii'

config_path = './configs/paraphronima_corneas_without_random_rotation.yaml'
y_path = './dataset/paraphronima_corneas/whole/test_labels_10/P_crassipes_FEG190213_003b_02_head-corneas.csv'
y_hat_path = './output/P_crassipes_FEG190213_003b_02_head-image.logs_paraphronima_corneas_without_random_scale_lightning_logs_version_3_checkpoints_last_x_3_y_3_z_3_average_threshold_0.5_prediction_peak_min_val_0_25_method_center_of_mass.resampled_space_peaks.csv'
mct_path = './dataset/paraphronima_corneas/whole/test_images_10/P_crassipes_FEG190213_003b_02_head-image.nii'

In [None]:
from deep_radiologist.lightning_modules import Model
import yaml
from yaml.loader import SafeLoader
import numpy as np
import torchio as tio
from itables import init_notebook_mode

init_notebook_mode(all_interactive=True)



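# load the coordinates to analyse
def load_coordinates(path, flip_axes=False, mct_path=None):
    """Load a CSV of point coordinates and return it as a list of lists.

    If `flip_axes` is set, the first two coordinates are mirrored against the
    matching spatial dimensions of the image at `mct_path`.
    """
    locations = np.loadtxt(
        path,
        delimiter=',',
        ndmin=2,
        dtype=np.float64
    )
    if not flip_axes:
        return locations.tolist()

    if not mct_path:
        raise ValueError('You must specify `mct_path` if you need to flip_axes')

    # torchio image shapes are (channels, x, y, z), so shape[1] and shape[2]
    # are the sizes of the first two spatial axes
    mct = tio.ScalarImage(mct_path)
    locations[:, 0] = mct.shape[1] - locations[:, 0]
    locations[:, 1] = mct.shape[2] - locations[:, 1]

    return locations.tolist()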


y_hat = load_coordinates(y_hat_path, flip_axes=True, mct_path=mct_path)
y = load_coordinates(y_path)

Let's plot y against y_hat.

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objects as go

def plot(*datasets, colors, labels=None, backend='matplotlib'):
    if backend == 'matplotlib':
        # Create a 3D plot using Matplotlib
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        
        # Iterate over datasets and corresponding colors
        for i, (data, color) in enumerate(zip(datasets, colors)):
            # Unpack the data into x, y, z coordinates
            x = [point[0] for point in data]
            y = [point[1] for point in data]
            z = [point[2] for point in data]
            
            # Plot each dataset
            ax.scatter(x, y, z, color=color, label=f'dataset_{i+1}' if labels is None else labels[i])
        
        # Label the axes
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        
        # Add a legend
        ax.legend()
        
        # Show the plot
        plt.show()
        return fig
    elif backend == 'plotly':
        # Create a 3D plot using Plotly
        fig = go.Figure()

        # Iterate over datasets and corresponding colors
        for i, (data, color) in enumerate(zip(datasets, colors)):
            # Unpack the data into x, y, z coordinates
            x = [point[0] for point in data]
            y = [point[1] for point in data]
            z = [point[2] for point in data]
            
            # Plot each dataset
            fig.add_trace(go.Scatter3d(
                x=x, y=y, z=z,
                mode='markers',
                marker=dict(color=color, size=2),
                name=f'dataset_{i+1}' if labels is None else labels[i]
            ))

        # Label the axes
        fig.update_layout(
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z'
            )
        )

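        # Show the plot
        fig.show()
        return fig
    else:
        # Fail loudly instead of silently returning None
        raise ValueError(f"Unknown backend: {backend!r}")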

In [None]:
def evaluate(config_path, y, y_hat):
    # get the accuracy metrics for this image
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=SafeLoader)

    # override correct_prediction_distance to something more appropriate
    config["correct_prediction_distance"] = 20
    
    model = Model(config)
    
    tp, fp, fn, loc_err = model._get_acc_metrics(y_hat, y)
    tps, fps, fns = model._get_acc_metrics(y_hat, y, return_coords=True)
    
    return tp, fp, fn, loc_err, tps, fps, fns

tp, fp, fn, loc_err, tps, fps, fns = evaluate(config_path, y, y_hat)
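
To give an intuition for what a distance threshold such as correct_prediction_distance means, here is a minimal sketch of one way to count true positives, false positives and false negatives. It is a hypothetical stand-in, not the implementation behind `Model._get_acc_metrics`: it assumes greedy one-to-one matching by Euclidean distance, and the name `match_points` and its arguments are made up for illustration.

```python
import math


def match_points(predicted, ground_truth, max_distance):
    """Greedily pair predictions with labels (hypothetical, not `_get_acc_metrics`)."""
    # every prediction/label pair that is close enough, closest pairs first
    candidates = sorted(
        (math.dist(p, g), i, j)
        for i, p in enumerate(predicted)
        for j, g in enumerate(ground_truth)
        if math.dist(p, g) <= max_distance
    )

    used_predictions, used_labels, errors = set(), set(), []
    for distance, i, j in candidates:
        # each prediction and each label may be matched at most once
        if i in used_predictions or j in used_labels:
            continue
        used_predictions.add(i)
        used_labels.add(j)
        errors.append(distance)

    num_tp = len(errors)
    num_fp = len(predicted) - num_tp      # predictions without a label
    num_fn = len(ground_truth) - num_tp   # labels without a prediction
    mean_error = sum(errors) / num_tp if num_tp else float('nan')
    return num_tp, num_fp, num_fn, mean_error


example_predicted = [[0, 0, 0], [10, 10, 10], [100, 100, 100]]
example_labels = [[1, 0, 0], [12, 10, 10], [50, 50, 50]]
example_result = match_points(example_predicted, example_labels, max_distance=20)
print(example_result)
```

With a threshold of 20, the first two predictions are matched to their nearest labels, the third prediction counts as a false positive, and the third label counts as a false negative.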

In [None]:
print('True positives:', tp)
print('False positives:', fp)
print('False negatives:', fn)

In [None]:
plot(
    tps,
    fps,
    fns,
    colors=['green', 'blue', 'red'],
    labels=['True positive', 'False positives', 'False negatives'],
    backend='plotly'
)

That looks like a great result. Now let's run it for the rest of the scans and models.

## Evaluate all models

In [None]:
import os
from pprint import pprint
import pickle
import pandas as pd
import yaml


folder_path = './output/'


substrings = [
  "fiddlercrab_corneas_lightning_logs_version_26",
  "fiddlercrab_rhabdoms_lightning_logs_version_10",
  "paraphronima_corneas_lightning_logs_version_11",
  "paraphronima_rhabdoms_lightning_logs_version_23",
  "paraphronima_rhabdoms_without_target_heatmap_masking_lightning_logs_version_1",
  "fiddlercrab_rhabdoms_with_target_heatmap_masking_lightning_logs_version_2",
  "paraphronima_rhabdoms_without_elastic_deformation_lightning_logs_version_2",
  "paraphronima_corneas_without_elastic_deformation_lightning_logs_version_2",
  "fiddlercrab_corneas_without_elastic_deformation_lightning_logs_version_2",
  "fiddlercrab_rhabdoms_without_elastic_deformation_lightning_logs_version_3",
  "paraphronima_rhabdoms_without_random_scale_lightning_logs_version_3",
  "paraphronima_corneas_without_random_scale_lightning_logs_version_3",
  "fiddlercrab_rhabdoms_without_random_scale_lightning_logs_version_3",
  "paraphronima_rhabdoms_without_random_rotation_lightning_logs_version_2",
  "paraphronima_corneas_without_random_rotation_lightning_logs_version_1",
  "fiddlercrab_corneas_without_random_rotation_lightning_logs_version_2",
  "fiddlercrab_rhabdoms_without_random_rotation_lightning_logs_version_0",
  "fiddlercrab_corneas_hist_std_lightning_logs_version_0",
  "fiddlercrab_rhabdoms_hist_std_lightning_logs_version_0",
  "fiddlercrab_corneas_random_affine_prop_1_lightning_logs_version_0",
  "fiddlercrab_corneas_hist_std_z_norm_lightning_logs_version_0",
  "fiddlercrab_rhabdoms_hist_std_z_norm_lightning_logs_version_0"
]

def find_files_with_substring_and_suffix(folder_path, substring, suffix, other_words):
    # keep files whose name contains both substrings and ends with the suffix
    matching_files = []
    for filename in os.listdir(folder_path):
        if substring in filename and filename.endswith(suffix) and other_words in filename:
            matching_files.append(filename)
    return matching_files

def split_filenames_by_hyphen(filenames):
    # the scan name is the part of the file name before the first hyphen
    split_filenames = [filename.split('-')[0] for filename in filenames]
    return split_filenames


for substring in substrings:

    suffix = 'resampled_space_peaks.csv'
    files = find_files_with_substring_and_suffix(folder_path, substring, suffix, 'x_3_y_3_z_3')
    split_files = split_filenames_by_hyphen(files)

    print(files[:3])
    print(split_files[:3])
    print(len(files))

    prefix = substring.split('_lightning')[0]
    feature_to_find = 'rhabdoms' if 'rhabdoms' in prefix else 'corneas'

    config_file = f'./configs/{prefix}.yaml'

    # read the config
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)
    
    # every file found for this model shares the same config
    config_paths = [config_file] * len(files)
    y_paths = [f'{config["test_labels_dir"]}/{name}-{config["label_suffix"]}.csv'.replace("patches", "whole") for name in split_files]
    y_hat_paths = [os.path.join(folder_path, file) for file in files]
    mct_paths = [f'{config["test_images_dir"]}/{name}-image.nii'.replace("patches", "whole") for name in split_files]

    print(config_paths[:5])
    print(y_paths[:5])
    print(y_hat_paths[:5])
    print(mct_paths[:5])

    plot_results = False

    for config_path, y_path, y_hat_path, mct_path in zip(config_paths, y_paths, y_hat_paths, mct_paths):
        results = []

        y_hat = load_coordinates(y_hat_path, flip_axes=True, mct_path=mct_path)
        y = load_coordinates(y_path)

        tp, fp, fn, loc_err, tps, fps, fns = evaluate(config_path, y, y_hat)

        print('True positives:', tp)
        print('False positives:', fp)
        print('False negatives:', fn)
        
        if plot_results:
            plot(
                tps,
                fps,
                fns,
                colors=['green', 'blue', 'red'],
                labels=['True positive', 'False positives', 'False negatives'],
                backend='plotly'
            )

        results.append(
            {
                'mct_path': mct_path,
                'y_path': y_path,
                'y_hat_path': y_hat_path,
                'config_path': config_path,
                'num_tps': tp,
                'num_fps': fp,
                'num_fns': fn,
                'tps': tps,
                'fps': fps,
                'fns': fns
            }
        )
        
        file_name_without_extension = os.path.splitext(os.path.basename(y_hat_path))[0]
        # shorten if longer than 30 characters
        # NOTE: outputs whose names share their first 30 characters are written
        # to the same files, so later ones overwrite earlier ones
        if len(file_name_without_extension) > 30:
            file_name_without_extension = file_name_without_extension[:30]

        # make sure the output folder exists before writing to it
        os.makedirs('./analysis_output', exist_ok=True)

        print(f'saving pickle to ./analysis_output/{file_name_without_extension}_results.pickle')
        with open(f'./analysis_output/{file_name_without_extension}_results.pickle', 'wb') as handle:
            pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # specify the keys you want to keep
        keys_to_keep = ['mct_path', 'config_path', 'y_path', 'y_hat_path', 'num_tps', 'num_fps', 'num_fns']

        # create a new list of dictionaries with only the specified keys
        filtered_data = [{k: v for k, v in d.items() if k in keys_to_keep} for d in results]

        df = pd.DataFrame(filtered_data)


        # Calculate precision
        precision = df['num_tps'] / (df['num_tps'] + df['num_fps'])

        # Calculate recall
        recall = df['num_tps'] / (df['num_tps'] + df['num_fns'])

        # Calculate F1 score (the harmonic mean of precision and recall)
        f1_score = 2 * (precision * recall) / (precision + recall)

        # add precision, recall and F1 score
        df['precision'] = precision
        df['recall'] = recall
        df['f1'] = f1_score

        df.to_csv(f'./analysis_output/{file_name_without_extension}.csv')
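
For reference, the three metrics computed above follow the usual definitions: precision = TP / (TP + FP), recall = TP / (TP + FN), and F1 is the harmonic mean of the two. The sketch below is a standalone scalar version. The helper name `precision_recall_f1` is hypothetical and not part of the repository. It returns NaN whenever a metric is undefined, which should mirror the NaN that the pandas division produces for 0 / 0.

```python
import math


def precision_recall_f1(tp, fp, fn):
    """Return (precision, recall, f1) from counts, with NaN where undefined."""
    nan = float('nan')
    # precision: fraction of predictions that are correct
    precision = tp / (tp + fp) if (tp + fp) > 0 else nan
    # recall: fraction of labels that were found
    recall = tp / (tp + fn) if (tp + fn) > 0 else nan
    # F1 is undefined if either input is undefined or both are zero
    if math.isnan(precision) or math.isnan(recall) or (precision + recall) == 0:
        return precision, recall, nan
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


example_metrics = precision_recall_f1(tp=8, fp=2, fn=2)
print(example_metrics)
```

Rows with a NaN F1 score are worth keeping an eye on, since they usually mean a scan had no predictions or no matches at all.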


In [None]:
import pandas as pd
import os
import glob

all_files = glob.glob("./analysis_output/*.csv")

# the first column of each CSV is the index written by `DataFrame.to_csv`
df = pd.concat((pd.read_csv(f, index_col=0) for f in all_files), ignore_index=True)
pickle_paths = glob.glob("./analysis_output/*.pickle")

pickle_suffix = '_results.pickle'

# create a dictionary mapping the name stem of each pickle file to its full path
pickle_map = {
    os.path.basename(path)[:-len(pickle_suffix)]: path
    for path in pickle_paths
    if path.endswith(pickle_suffix)
}

def link_pickle_to_y_hat(y_hat_path):
    ''' Rebuild the name stem used when the results were saved (the y_hat file name without its extension, shortened to 30 characters) and use it to find the corresponding pickle file '''
    y_hat_stem = os.path.splitext(os.path.basename(y_hat_path))[0][:30]
    return pickle_map.get(y_hat_stem, None)  # Returns None if no match is found

# Apply the function to the 'y_hat_path' column to populate the 'pickle' column
df['pickle'] = df['y_hat_path'].apply(link_pickle_to_y_hat)

df

In [None]:
# Exclude groups where all 'f1' values are NaN
valid_groups = df.groupby('config_path')['f1'].transform('max').notna()
df = df[valid_groups]

In [None]:
# Sort the filtered DataFrame by 'f1' in descending order (highest first); NaN values are placed last
best_df = df.sort_values(by='f1', ascending=False)

# keep only the best-scoring row for each config_path and scan
best_df = best_df.groupby(['config_path', 'mct_path']).head(1)

best_df.to_csv('best_models.csv')

In [None]:
# Display the filtered DataFrame
print(best_df)
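
The best-row-per-group selection can also be written with `idxmax`. The toy example below is a minimal sketch of that alternative. All file names and scores in it are made up for illustration, and it assumes that no group consists entirely of NaN scores.

```python
import pandas as pd

# made-up scores for two configs, each evaluated with two peak detection methods
toy = pd.DataFrame({
    'config_path': ['a.yaml', 'a.yaml', 'b.yaml', 'b.yaml'],
    'mct_path': ['scan_1.nii', 'scan_1.nii', 'scan_2.nii', 'scan_2.nii'],
    'y_hat_path': ['a_com.csv', 'a_max.csv', 'b_com.csv', 'b_max.csv'],
    'f1': [0.70, 0.85, 0.90, 0.60],
})

# index label of the highest F1 score within each (config, scan) group
best_idx = toy.groupby(['config_path', 'mct_path'])['f1'].idxmax()

# select those rows from the original frame
toy_best = toy.loc[best_idx].reset_index(drop=True)
print(toy_best)
```

Both approaches keep one row per config and scan. Sorting and taking `head(1)` tolerates NaN scores more gracefully, because NaN values are sorted last.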

This shows:
- max_filter and center_of_mass are good in different circumstances
- more orientations are better (>5), but let's use 3, 3, 3
- a smaller avg_threshold (0.2) is better for flammula, but a greater one (0.5) is better for dampieri
- a smaller peak_min_val (0.2) is better for flammula, but a greater one (0.5) is better for dampieri

Let's find the best parameters for inference for both dampieri and flammula.

The hyperiid rhabdom model appears to produce many false positives in predicted areas away from the rhabdoms. This is likely because the training data only included areas of the scans around rhabdoms. To judge the model's performance under the same conditions, we need to evaluate only in areas of the scans around rhabdoms. We do this below:

- Load paraphronima rhabdom inference
- add filter to only evaluate area around ground truth labels
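
The filtering step listed above could be sketched as follows. This is a minimal illustration that assumes "around ground truth labels" means within a fixed Euclidean distance of the nearest label; the helper name `keep_predictions_near_labels` and the `max_distance` value are hypothetical.

```python
import numpy as np


def keep_predictions_near_labels(predicted, ground_truth, max_distance):
    """Keep predictions within `max_distance` of any label (hypothetical helper)."""
    predicted = np.asarray(predicted, dtype=float).reshape(-1, 3)
    ground_truth = np.asarray(ground_truth, dtype=float).reshape(-1, 3)
    if len(predicted) == 0 or len(ground_truth) == 0:
        return []

    # pairwise distances, shape (n_predicted, n_ground_truth)
    distances = np.linalg.norm(
        predicted[:, None, :] - ground_truth[None, :, :], axis=-1
    )
    # a prediction is kept if its nearest label is close enough
    keep = distances.min(axis=1) <= max_distance
    return predicted[keep].tolist()


example_predictions = [[0, 0, 0], [5, 5, 5], [200, 200, 200]]
example_labels = [[1, 1, 1], [6, 5, 5]]
kept = keep_predictions_near_labels(example_predictions, example_labels, max_distance=50)
print(kept)
```

The pairwise distance matrix grows with the product of the two point counts, so for very large point sets a spatial index would be a better fit than this brute-force version.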
  

Note: if the model is evaluated in this way, we need to make sure readers know that elongated features that are only partially labelled (with single points) can be detected using our masking approach. However, the masking approach can introduce false positives into the training data by marking all high or low values around the point label, which makes the resulting model prone to false positives. It is therefore recommended to crop scans around the area of interest before running inference. Combining masking with a cropped scan for inference can then provide great results on partially labelled data.