In [1]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import SimpleITK as sitk
import torch
import monai
import pickle
from monai.metrics import DiceMetric, SurfaceDistanceMetric
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureChannelFirst,
    EnsureType,
    Lambda,
    LoadImage,
    Orientation,
    ToDevice,
    Transpose,
    LabelFilterd,
    MapLabelValued
    
    
)

from tqdm import tqdm
from lighter.utils.dynamic_imports import import_module_from_path
from pathlib import Path
from tqdm import tqdm
import torchmetrics
from totalsegmentator.map_to_binary import class_map

import_module_from_path("project", "/home/suraj/Repositories/lighter-ct-fm")
from project.data import get_ts_class_indices, get_ts_class_labels

In [2]:
label_map = class_map["total"]
pred_dir = Path("/mnt/data1/CT_FM/evaluations/totalseg/predictions")
dataset_path = Path("/mnt/data1/TotalSegmentator/v2/processed")

In [3]:
results = []
device = "cuda:0"
# group_filter = [
#     "few-shot_100", "few-shot_50", "few-shot_20", "few-shot_10", "few-shot_5",
# ]
group_filter = [
    "vista_v2",
]

In [16]:


for model_dir in pred_dir.glob("*"):
    group = "_".join(model_dir.name.split("_")[-2:])
    if group not in group_filter:
        continue

    group_map = {
        "quick_v2": "v2",
        "merlin_V2": "merlin_v2",
        "vista_v2": "v2",
        "fulltune_v2": "v2",
        "few-shot_100": "v2",
        "few-shot_50": "v2",
        "few-shot_20": "v2",
        "few-shot_10": "v2",
        "few-shot_5": "v2",
    }

    model_map = {
        "ct_fm": "CT FM (Ours)",
        "baseline": "Random Init.",
        "suprem_unet": "SuPREM"
    }

    if group in group_map:
        data_group = group_map[group]
    else:
        data_group = group
    
    model_name = "_".join(model_dir.name.split("_")[:-2])

    if model_name in model_map:
        model_name = model_map[model_name]

    print(f"Evaluating... Group: {group}, Model: {model_name}\n")
    class_indices = get_ts_class_indices(group=data_group)
    class_labels = get_ts_class_labels(class_indices, group=data_group)
    out_channels = len(class_indices)

    base_transforms = Compose([
        LoadImage(),
        ToDevice(device=device),
        EnsureChannelFirst(),
        EnsureType(data_type="tensor", dtype="int"),
        Orientation(axcodes="SPL" if "suprem" not in model_name else "RAS"),
    ])

    mapping_transforms = Compose([
        Lambda(lambda x: {"label": x}),
        LabelFilterd(keys="label", applied_labels=class_indices),
        MapLabelValued(keys="label", orig_labels=class_indices, target_labels=list(range(out_channels))),
        Lambda(lambda x: x["label"])
    ])

    target_transforms = Compose([base_transforms, mapping_transforms])

    print("Calculating Dice Scores... \n")
    dice_dict = {label: [] for label in class_labels}
    overall_dice_list = []
    image_samples = []
    preds_list = list(model_dir.glob("*"))
    for pred_path in tqdm(preds_list):
        sid = pred_path.stem.split(".")[0]
        label = target_transforms(dataset_path / sid / "label.nii.gz").unsqueeze(0)
        image = base_transforms(dataset_path / sid / "ct.nii.gz").unsqueeze(0)
        pred = base_transforms(pred_path).unsqueeze(0)

        res = monai.metrics.compute_dice(pred, label, num_classes=out_channels).squeeze().tolist()
        agg_res = monai.metrics.compute_generalized_dice(pred, label, include_background=False, weight_type='uniform').item()
        overall_dice_list.append(agg_res)
        for label_name, score in zip(class_labels, res):
            dice_dict[label_name].append(score)
        
        if len(image_samples) < 5:
            image_samples.append({
                "sid": sid,
                "label": label,
                "image": image,
                "pred": pred
            })

    dice_dict.pop("background", None)
    class_aggregate_dict = {k: np.nanmean(v) for k, v in dice_dict.items()}
    class_aggregate_dice = np.nanmean(list(class_aggregate_dict.values()))
    overall_dice = np.nanmean(overall_dice_list)
    print(overall_dice, "\n")
    results.append({"group": group, "model": model_name, "all_scores": dice_dict, "class_dice_scores": dice_dict, "overall_dice": overall_dice, "class_aggregate_dice": class_aggregate_dice, "image_samples": image_samples})


Evaluating... Group: vista_v2, Model: CT FM (Ours)

Number of classes: 118
Calculating Dice Scores... 



  0%|          | 0/248 [00:00<?, ?it/s]

100%|██████████| 248/248 [04:23<00:00,  1.06s/it]


66.48452369628414 

Evaluating... Group: vista_v2, Model: Random Init.

Number of classes: 118
Calculating Dice Scores... 



100%|██████████| 248/248 [04:06<00:00,  1.01it/s]


66.38507990683279 

Evaluating... Group: vista_v2, Model: SuPREM

Number of classes: 118
Calculating Dice Scores... 



100%|██████████| 248/248 [04:06<00:00,  1.01it/s]

64.71189448141283 






In [4]:

# # Save results to a pkl file
import pickle
# with open('artifacts/vista_totalseg.pkl', 'wb') as f:
#     pickle.dump(results, f)

# # Load results from the pkl file
with open('artifacts/vista_totalseg.pkl', 'rb') as f:
    results = pickle.load(f)

In [6]:
import numpy as np
from scipy import stats

for idx, result in enumerate(results):
    # Extract the class dice scores
    class_dice_scores = result['class_dice_scores']

    # Calculate the mean of all organ scores
    mean_score = np.nanmean([np.nanmean(scores) for scores in class_dice_scores.values()])

    # Calculate the standard error
    all_scores = [score for scores in class_dice_scores.values() for score in scores if not np.isnan(score)]
    se = stats.sem(all_scores)

    # Calculate 95% CI
    ci_95 = stats.t.interval(confidence=0.95, df=len(all_scores)-1, loc=mean_score, scale=se)

    # Add the calculated statistics to the result dictionary
    results[idx]['confidence_interval'] = ci_95

    print(f"Model: {result['model']}, Group: {result['group']}")
    print(f"Overall Mean Dice Score: {mean_score:.4f}")
    print(f"95% Confidence Interval: ({ci_95[0]:.4f}, {ci_95[1]:.4f})")
    print("---")

Model: CT FM (Ours), Group: vista_v2
Overall Mean Dice Score: 0.8981
95% Confidence Interval: (0.8959, 0.9004)
---
Model: Random Init., Group: vista_v2
Overall Mean Dice Score: 0.8959
95% Confidence Interval: (0.8936, 0.8982)
---
Model: SuPREM, Group: vista_v2
Overall Mean Dice Score: 0.8695
95% Confidence Interval: (0.8668, 0.8721)
---


In [19]:

# import seaborn as sns
# # Convert the data into a pandas DataFrame
# rows = []
# for entry in results:
#     for organ, dice_scores in entry['all_scores'].items():
#         rows.append({
#             'Model': entry['model'],
#             'Group': entry['group'],
#             'Organ': organ,
#             'Dice Score': np.nanmean(dice_scores),
#             'Overall Dice': entry['overall_dice'],
#             'Macro Dice': entry['class_aggregate_dice']
#         })

# additional_entries = {
#     "vista_v2": [
#         {"Group": "vista_v2", "Model": "AutoSeg3D*", "Macro Dice": 0.882},
#         {"Group": "vista_v2", "Model": "nnUnet*", "Macro Dice": 0.906},
#         {"Group": "vista_v2", "Model": "VISTA3D Auto*", "Macro Dice": 0.893}
#     ],
#     "merlin_V2":  [{"Group": "merlin_V2", "Model": "Merlin FM", "Macro Dice": 0.86}]
# }

# model_order = ["CT FM (Ours)", "Random Init.", "SuPREM"]
# # model_order = ["CT FM (Ours)", "Random Init."]

# for group in group_filter:
#     if group in additional_entries:
#         rows.extend(additional_entries[group])
#         model_order.extend([entry["Model"] for entry in additional_entries[group]])


# print(model_order)

# df = pd.DataFrame(rows)
# df['Model'] = pd.Categorical(df['Model'], categories=model_order, ordered=True)
# df['Group'] = df['Group'].replace({
#     'few-shot_5': '5',
#     'few-shot_10': '10',
#     'few-shot_20': '20',
#     'few-shot_50': '50',
#     'few-shot_100': '100'
# })
# df['Group'] = pd.Categorical(df['Group'], categories=['5', '10', '20', '50', '100'], ordered=True)
# df = df.sort_values(by=["Model", "Group"])
# # df = df.sort_values(by=["Model"])

# font_size = 30
# gray_palette = sns.color_palette("Blues_r", 6).as_hex()
# color_list = gray_palette
    
# # Overall Dice Score comparison
# fig = px.line(
#     df[["Model", "Overall Dice", "Macro Dice", "Group"]].drop_duplicates(),
#     x='Group',
#     y='Macro Dice',
#     color='Model',
#     title='',
#     height=800,
#     width=800,
#     template='plotly_white',
#     color_discrete_sequence=color_list,
#     markers=True,
# )
# # fig = px.bar(
# #     df[["Model", "Overall Dice", "Macro Dice", "Group"]].drop_duplicates(),
# #     x='Group',
# #     y='Macro Dice',
# #     color='Model',
# #     title='',
# #     height=600,
# #     width=400,
# #     template='plotly_white',
# #     barmode="group",
# #     range_y=[0.8, 0.92],
# #     color_discrete_sequence=color_list,
# # )

# # Update layout for a more minimalist look
# fig.update_layout(
#     plot_bgcolor='white',
#     paper_bgcolor='white',
#     font=dict(color='black', size=font_size),
#     title=dict(font=dict(size=font_size)),
    
#     xaxis=dict(
#         showline=True,
#         linewidth=1,
#         linecolor='black',
#         mirror=False
#     ),
#     yaxis=dict(
#         title='Dice score',
#         showline=True,
#         linewidth=1,
#         linecolor='black',
#         mirror=False
#     ),
#     legend=dict(
#         orientation="h",
#         yanchor="bottom",
#         y=1.02,
#         xanchor="right",
#         x=1
#     )
# )

# # Remove gridlines
# fig.update_xaxes(showgrid=False)
# fig.update_yaxes(showgrid=True)
# # fig.update_traces(marker=dict(pattern=dict(shape="\\")), selector=dict(name="AutoSeg3D*"))
# # fig.update_traces(marker=dict(pattern=dict(shape="\\")), selector=dict(name="nnUnet*"))
# # fig.update_traces(marker=dict(pattern=dict(shape="\\")), selector=dict(name="VISTA3D Auto*"))

# # Update marker style
# fig.update_traces(
#     marker=dict(size=25, line=dict(width=2, color='black')),
#     line=dict(width=5)
# )

# fig.show()


In [30]:
import seaborn as sns
# Convert the data into a pandas DataFrame
rows = []
for entry in results:
    for organ, dice_scores in entry['all_scores'].items():
        rows.append({
            'Model': entry['model'],
            'Group': entry['group'],
            'Organ': organ,
            'Dice Score': np.nanmean(dice_scores),
            'Overall Dice': entry['overall_dice'],
            'Macro Dice': entry['class_aggregate_dice'],
            'Upper CI': entry['confidence_interval'][1] - entry['class_aggregate_dice'],
            'Lower CI': entry['class_aggregate_dice'] - entry['confidence_interval'][0]
        })

additional_entries = {
    "vista_v2": [
        {"Group": "vista_v2", "Model": "AutoSeg3D*", "Macro Dice": 0.882, "Upper CI": 0, "Lower CI": 0},
        # {"Group": "vista_v2", "Model": "nnUnet*", "Macro Dice": 0.906, "Upper CI": 0, "Lower CI": 0},
        {"Group": "vista_v2", "Model": "VISTA3D Auto*", "Macro Dice": 0.893, "Upper CI": 0, "Lower CI": 0}
    ],
    "merlin_V2":  [{"Group": "merlin_V2", "Model": "Merlin FM", "Macro Dice": 0.86, "Upper CI": 0, "Lower CI": 0}]
}

model_order = ["CT FM (Ours)", "Random Init.", "SuPREM"]
# model_order = ["CT FM (Ours)", "Random Init."]

for group in group_filter:
    if group in additional_entries:
        rows.extend(additional_entries[group])
        model_order.extend([entry["Model"] for entry in additional_entries[group]])


print(model_order)

df = pd.DataFrame(rows)
df['Model'] = pd.Categorical(df['Model'], categories=model_order, ordered=True)
df = df.sort_values(by=["Model"])

font_size = 26
gray_palette = sns.color_palette("Blues_r", 6).as_hex()
color_list = gray_palette
    
fig = px.bar(
    df[["Model", "Overall Dice", "Macro Dice", "Group", "Upper CI", "Lower CI"]].drop_duplicates(),
    x='Group',
    y='Macro Dice',
    color='Model',
    title='',
    height=800,
    width=600,
    template='plotly_white',
    barmode="group",
    range_y=[0.8, 0.92],
    color_discrete_sequence=color_list,
    error_y='Upper CI',
    error_y_minus='Lower CI'
)

# Update layout for a more minimalist look
fig.update_layout(
    plot_bgcolor='white',
    paper_bgcolor='white',
    font=dict(color='black', size=font_size),
    title=dict(font=dict(size=font_size)),
    
    xaxis=dict(
        showline=True,
        linewidth=1,
        linecolor='black',
        mirror=False
    ),
    yaxis=dict(
        title='Dice score',
        showline=True,
        linewidth=1,
        linecolor='black',
        mirror=False
    ),
    legend=dict(
        # orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1,
        # itemsizing='constant',
        # traceorder='normal',
        # tracegroupgap=1,
        # itemclick=False,
        # itemdoubleclick=False
    )
)

# Remove gridlines
fig.update_xaxes(showgrid=False)
fig.update_yaxes(showgrid=True)

# Limit the legend to two items by updating traces
fig.update_traces(marker=dict(pattern=dict(shape="\\")), selector=dict(name="AutoSeg3D*"))
fig.update_traces(marker=dict(pattern=dict(shape="\\")), selector=dict(name="nnUnet*"))

# Add error bars
# fig.update_traces(error_y=dict(type='data', symmetric=False, array=df['Upper CI'].drop_duplicates() - df['Macro Dice'].drop_duplicates(), arrayminus=df['Macro Dice'].drop_duplicates() - df['Lower CI'].drop_duplicates()))

fig.show()


['CT FM (Ours)', 'Random Init.', 'SuPREM', 'AutoSeg3D*', 'VISTA3D Auto*']


: 

In [21]:
from totalsegmentator.map_to_binary import class_map, class_map_5_parts

In [22]:
organ_df = df[["Model", "Organ", "Dice Score"]].copy()

# Get unique models and organs
models = organ_df["Model"].unique()
models = [model for model in models if model in ["CT FM (Ours)", "Random Init.", "SuPREM"]]
organs = organ_df["Organ"].unique()

# Initialize a dictionary to store win/loss counts
win_loss_counts = {model: {opponent: {"wins": 0, "losses": 0} for opponent in models if opponent != model} for model in models}

# Compare each pair of models for each organ
for organ in organs:
    organ_data = organ_df[organ_df["Organ"] == organ]
    for model in models:
        for opponent in models:
            if model != opponent:
                model_scores = organ_data[organ_data["Model"] == model]["Dice Score"].values
                opponent_scores = organ_data[organ_data["Model"] == opponent]["Dice Score"].values
                if len(model_scores) > 0 and len(opponent_scores) > 0:
                    model_score = model_scores[0]
                    opponent_score = opponent_scores[0]
                    if model_score > opponent_score:
                        win_loss_counts[model][opponent]["wins"] += 1
                    elif model_score < opponent_score:
                        win_loss_counts[model][opponent]["losses"] += 1

# Calculate win and loss percentages
win_loss_percentages = {model: {} for model in models}
for model in models:
    for opponent in win_loss_counts[model]:
        total_comparisons = win_loss_counts[model][opponent]["wins"] + win_loss_counts[model][opponent]["losses"]
        if total_comparisons > 0:
            win_percentage = win_loss_counts[model][opponent]["wins"] / total_comparisons * 100
            loss_percentage = win_loss_counts[model][opponent]["losses"] / total_comparisons * 100
            win_loss_percentages[model][opponent] = {"win": win_percentage, "loss": loss_percentage}

# Create a horizontal bar plot for each model
for model in models:
    opponents = list(win_loss_percentages[model].keys())
    win_percentages = [win_loss_percentages[model][opponent]["win"] for opponent in opponents]
    loss_percentages = [win_loss_percentages[model][opponent]["loss"] for opponent in opponents]
    
    fig = go.Figure()
    
    # Add win percentage bars
    fig.add_trace(go.Bar(
        y=opponents,
        x=win_percentages,
        name='Win',
        orientation='h',
        width=0.3,
        marker=dict(color=color_list[0]),
        text=[f'{x:.1f}%' for x in win_percentages],
        textposition='inside',
        insidetextfont=dict(color='white'),
    ))
    
    # Add loss percentage bars
    fig.add_trace(go.Bar(
        y=opponents,
        x=loss_percentages,
        name='Loss',
        orientation='h',
        width=0.3,
        marker=dict(color='lightgray'),
        text=[f'{x:.1f}%' for x in loss_percentages],
        textposition='inside',
        insidetextfont=dict(color='white'),
    ))
    
    fig.update_layout(
        title=f"Win/Loss Percentage for {model}",
        xaxis_title="Percentage",
        yaxis_title="Opponent",
        barmode='stack',
        height=600,
        width=600,
        plot_bgcolor='white',
        paper_bgcolor='white',
        font=dict(color='black', size=20),
        xaxis=dict(range=[0, 100], ticksuffix="%"),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        margin=dict(l=50, r=20, t=50, b=50)
    )
    
    fig.show()

In [23]:
df

Unnamed: 0,Model,Group,Organ,Dice Score,Overall Dice,Macro Dice,Upper CI,Lower CI
0,CT FM (Ours),vista_v2,spleen,0.975054,66.484524,0.898142,0.002286,0.002286
85,CT FM (Ours),vista_v2,autochthon_left,0.963524,66.484524,0.898142,0.002286,0.002286
84,CT FM (Ours),vista_v2,gluteus_minimus_right,0.935229,66.484524,0.898142,0.002286,0.002286
83,CT FM (Ours),vista_v2,gluteus_minimus_left,0.932531,66.484524,0.898142,0.002286,0.002286
82,CT FM (Ours),vista_v2,gluteus_medius_right,0.958140,66.484524,0.898142,0.002286,0.002286
...,...,...,...,...,...,...,...,...
265,SuPREM,vista_v2,vertebrae_T12,0.907227,64.711894,0.869470,0.002659,0.002659
277,SuPREM,vista_v2,vertebrae_C7,0.920082,64.711894,0.869470,0.002659,0.002659
351,AutoSeg3D*,vista_v2,,,,0.882000,0.000000,0.000000
352,nnUnet*,vista_v2,,,,0.906000,0.000000,0.000000


In [24]:
# Calculate the difference in Dice scores between CT-FM and Random Init.
# Filter out NaN values before calculating the difference
ct_fm_scores = organ_df[organ_df['Model'] == 'CT FM (Ours)'].set_index('Organ')['Dice Score']
random_init_scores = organ_df[organ_df['Model'] == 'Random Init.'].set_index('Organ')['Dice Score']

# Align the indices and drop NaN values
ct_fm_scores, random_init_scores = ct_fm_scores.align(random_init_scores, join='inner')
ct_fm_scores = ct_fm_scores.dropna()
random_init_scores = random_init_scores.dropna()

# Calculate the difference
diff_df = ct_fm_scores - random_init_scores

# Sort the differences to get top improved and worsened classes
top_improved = diff_df.nlargest(25)
top_worsened = diff_df.nsmallest(25)

# Display the results
print("Top 5 classes improved by CT-FM compared to Random Init:")
print(top_improved)
print("\nTop 5 classes worsened by CT-FM compared to Random Init:")
print(top_worsened)

# Create a bar plot to visualize these differences using plotly express
import plotly.express as px
import pandas as pd

# Prepare data for plotting
improved_df = pd.DataFrame({'Organ': top_improved.index, 'Difference': top_improved.values, 'Category': 'Improved'})
worsened_df = pd.DataFrame({'Organ': top_worsened.index, 'Difference': top_worsened.values, 'Category': 'Worsened'})
plot_df = pd.concat([improved_df, worsened_df])

# Create the plot
fig = px.bar(plot_df, x='Organ', y='Difference', color='Category', barmode='group',
             title='Top 5 Classes Improved and Worsened by CT-FM compared to Random Init',
             labels={'Difference': 'Dice Score Difference'},
             color_discrete_map={'Improved': 'green', 'Worsened': 'red'},
             height=600, width=800)

# Customize the layout
fig.update_layout(
    xaxis_title='Organ',
    yaxis_title='Dice Score Difference',
    legend_title='Category',
    font=dict(size=14),
    xaxis_tickangle=-45,
    plot_bgcolor='white',
)

# Show the plot
fig.show()

Top 5 classes improved by CT-FM compared to Random Init:
Organ
vertebrae_C5               0.022891
rib_left_6                 0.019483
vertebrae_T5               0.016581
brain                      0.014847
vertebrae_T6               0.013873
rib_left_5                 0.013633
adrenal_gland_left         0.012929
vertebrae_C4               0.012844
rib_left_7                 0.012651
rib_right_10               0.012231
vertebrae_C3               0.011453
vertebrae_T4               0.010968
rib_right_8                0.010850
rib_right_7                0.010785
kidney_cyst_right          0.010614
rib_right_6                0.010391
rib_left_8                 0.010302
vertebrae_T3               0.010189
prostate                   0.009262
vertebrae_T7               0.008926
subclavian_artery_left     0.007484
rib_right_5                0.007332
rib_left_11                0.007134
rib_left_12                0.006942
subclavian_artery_right    0.006770
Name: Dice Score, dtype: float64

Top

In [25]:
for label_group in class_map_5_parts:
    labels = list(class_map_5_parts[label_group].values())
    sub_df = df[df["Organ"].isin(labels)]

    color_list = gray_palette

    # Ensure the 'Group' column is ordered as 5, 10, 20, 50, 100
    sub_df['Group'] = pd.Categorical(sub_df['Group'], categories=['few-shot_5', 'few-shot_10', 'few-shot_20', 'few-shot_50', 'few-shot_100'], ordered=True)

    # Overall Dice Score comparison
    fig = px.line(
        sub_df.groupby(['Model', 'Group'])['Dice Score'].apply(lambda x: np.nanmean(x)).reset_index(),
        x='Group',
        y='Dice Score',
        color='Model',
        title=label_group,
        height=600,
        width=600,
        template='plotly_white',
        color_discrete_sequence=color_list,
        markers=True  # Add dots to markers
    )

    # Add black borders to the bars
    fig.update_traces(marker=dict(size=15, line=dict(width=1, color='black')),
    line=dict(width=3))
    fig.update_layout(
        plot_bgcolor='white',
        paper_bgcolor='white',
        font=dict(color='black', size=16)
    )
    # fig.update_traces(texttemplate='%{y:.3f}', textposition='outside')
    fig.show()


In [31]:
list(df["Organ"].unique())

['spleen',
 'autochthon_left',
 'gluteus_minimus_right',
 'gluteus_minimus_left',
 'gluteus_medius_right',
 'gluteus_medius_left',
 'gluteus_maximus_right',
 'gluteus_maximus_left',
 'spinal_cord',
 'hip_right',
 'hip_left',
 'femur_right',
 'femur_left',
 'clavicula_right',
 'clavicula_left',
 'scapula_right',
 'scapula_left',
 'humerus_right',
 'humerus_left',
 'iliac_vena_right',
 'iliac_vena_left',
 'iliac_artery_right',
 'iliac_artery_left',
 'portal_vein_and_splenic_vein',
 'inferior_vena_cava',
 'superior_vena_cava',
 'autochthon_right',
 'iliopsoas_right',
 'brain',
 'skull',
 'costal_cartilages',
 'sternum',
 'rib_right_12',
 'rib_right_11',
 'rib_right_10',
 'rib_right_9',
 'rib_right_8',
 'rib_right_7',
 'rib_right_6',
 'rib_right_5',
 'rib_right_4',
 'rib_right_3',
 'atrial_appendage_left',
 'rib_right_2',
 'rib_left_12',
 'rib_left_11',
 'rib_left_10',
 'rib_left_9',
 'rib_left_8',
 'rib_left_7',
 'rib_left_6',
 'rib_left_5',
 'rib_left_4',
 'rib_left_3',
 'rib_left_2',
 '

In [36]:
list(df["Model"].unique())

['CT FM (Ours)',
 'Random Init.',
 'SuPREM',
 'AutoSeg3D*',
 'nnUnet*',
 'VISTA3D Auto*']

In [64]:
labels = {
        "liver": ["liver"],
    "spleen": ["spleen"],
    "adrenal gland": ["adrenal_gland_left", "adrenal_gland_right"],
    "vertebrae":  ['vertebrae_C1',
                    'vertebrae_L2',
                    'vertebrae_C3',
                    'vertebrae_C4',
                    'vertebrae_C5',
                    'vertebrae_C6',
                    'vertebrae_C7',
                    'vertebrae_C2',
                    'vertebrae_T2',
                    'vertebrae_T1',
                    'vertebrae_L1',
                    'vertebrae_T12',
                    'vertebrae_T11',
                    'vertebrae_T9',
                    'vertebrae_T10',
                    'vertebrae_T7',
                    'vertebrae_T6',
                    'vertebrae_T5',
                    'vertebrae_T4',
                    'vertebrae_T3',
                    'vertebrae_T8'],
    "pancreas": ["pancreas"],
    "kidney": ["kidney_left", "kidney_right"],
    "gallbladder": ["gallbladder"]
}

model_list = ['CT FM (Ours)', 'Random Init.', 'SuPREM']

for name, label in labels.items():
    sub_df = df[df["Organ"].isin(label)]
    color_list = gray_palette

    sub_df['Model'] = pd.Categorical(sub_df['Model'], categories=['CT FM (Ours)', 'Random Init.', 'SuPREM'], ordered=True)

    # Calculate min and max Dice Score for setting plot range
    min_dice_score = sub_df.groupby(['Model', 'Organ'])['Dice Score'].mean().min() - 0.01
    max_dice_score = sub_df.groupby(['Model', 'Organ'])['Dice Score'].mean().max() + 0.01

    # Average Dice Score over labels
    avg_dice_score = sub_df.groupby(['Model', 'Organ'])['Dice Score'].mean().reset_index()
    avg_dice_score = avg_dice_score.groupby('Model')['Dice Score'].mean().reset_index()

    # Overall Dice Score comparison
    fig = px.bar(
        avg_dice_score,
        x='Model',
        y='Dice Score',
        color='Model',
        title=name,
        height=600,
        width=600,
        template='plotly_white',
        color_discrete_sequence=color_list,
        text='Dice Score'  # Add scores on top
    )

    # Add black borders to the bars
    fig.update_traces(marker=dict(line=dict(width=1, color='black')),
                      texttemplate='%{text:.3f}', textposition='outside')  # Format text on top
    fig.update_layout(
        plot_bgcolor='white',
        paper_bgcolor='white',
        font=dict(color='black', size=16),
        yaxis=dict(range=[min_dice_score, max_dice_score])  # Set y-axis range
    )
    fig.show()


In [None]:
res

In [41]:
results[0]['all_scores'].keys()

dict_keys(['spleen', 'kidney_right', 'kidney_left', 'gallbladder', 'liver', 'stomach', 'pancreas', 'adrenal_gland_right', 'adrenal_gland_left', 'lung_upper_lobe_left', 'lung_lower_lobe_left', 'lung_upper_lobe_right', 'lung_middle_lobe_right', 'lung_lower_lobe_right', 'esophagus', 'trachea', 'thyroid_gland', 'small_bowel', 'duodenum', 'colon', 'urinary_bladder', 'prostate', 'kidney_cyst_left', 'kidney_cyst_right', 'sacrum', 'vertebrae_S1', 'vertebrae_L5', 'vertebrae_L4', 'vertebrae_L3', 'vertebrae_L2', 'vertebrae_L1', 'vertebrae_T12', 'vertebrae_T11', 'vertebrae_T10', 'vertebrae_T9', 'vertebrae_T8', 'vertebrae_T7', 'vertebrae_T6', 'vertebrae_T5', 'vertebrae_T4', 'vertebrae_T3', 'vertebrae_T2', 'vertebrae_T1', 'vertebrae_C7', 'vertebrae_C6', 'vertebrae_C5', 'vertebrae_C4', 'vertebrae_C3', 'vertebrae_C2', 'vertebrae_C1', 'heart', 'aorta', 'pulmonary_vein', 'brachiocephalic_trunk', 'subclavian_artery_right', 'subclavian_artery_left', 'common_carotid_artery_right', 'common_carotid_artery_le

In [43]:
for label_group in class_map_5_parts:
    labels = list(class_map_5_parts[label_group].values())
    sub_df = df[df["Organ"].isin(labels)]

    color_list = gray_palette
    font_size = 20
    # Simple categorization
    sub_df["Model"] = pd.Categorical(sub_df["Model"], categories=["CT FM (Ours)", "Random Init.", "SuPREM"], ordered=True)

    sub_df["Upper CI"] = 0
    sub_df["Lower CI"] = 0

    def mean_ci(x):
        data = [score for score in x['Dice Score'] if not np.isnan(score) or np.isinf(score)]
        mean = np.nanmean(data)
        se = stats.sem(data)
        ci = stats.t.interval(0.95, len(x)-1, loc=mean, scale=se)
        return pd.Series([mean, ci[1] - mean, mean - ci[0]])

    mean_dice = sub_df.groupby(['Model', 'Group'])[['Dice Score', 'Upper CI', 'Lower CI']].apply(lambda x: mean_ci(x)).reset_index()
    mean_dice.columns = ['Model', 'Group', 'Dice Score', 'Upper CI', 'Lower CI']

    print(mean_dice)

    # Pairwise model comparisons using Wilcoxon signed-rank test
    model_pairs = [("CT FM (Ours)", "Random Init."), ("CT FM (Ours)", "SuPREM"), ("Random Init.", "SuPREM")]
    for model1, model2 in model_pairs:
        for result in results:
            if result["model"] == model1:
                results1 = result
            elif result["model"] == model2:
                results2 = result

        scores1 = [score for label in labels for score in results1['all_scores'][label] if not np.isnan(score) or np.isinf(score)]
        scores2 = [score for label in labels for score in results2['all_scores'][label] if not np.isnan(score) or np.isinf(score)]

        print(len(scores1), len(scores2))
        if len(scores1) == len(scores2):  # Ensure equal length for Wilcoxon test
            stat, pvalue = stats.ttest_rel(scores1, scores2)
            print(f"P-value for {model1} vs {model2}: {pvalue}")

    # Overall Dice Score comparison
    fig = px.bar(
        mean_dice,
        x='Group',
        y='Dice Score',
        color='Model',
        title=label_group.split("_")[-1].capitalize(),
        height=600,
        width=400,
        template='plotly_white',
        barmode="group",
        range_y=[np.min(mean_dice["Dice Score"] - mean_dice["Lower CI"]) * 0.95, np.max(mean_dice["Dice Score"] + mean_dice["Upper CI"]) * 1.05],
        color_discrete_sequence=color_list,
        error_y='Upper CI',
        error_y_minus='Lower CI'
    )

    # Update layout for a more minimalist look
    fig.update_layout(
        plot_bgcolor='white',
        paper_bgcolor='white',
        font=dict(color='black', size=font_size),
        title=dict(font=dict(size=font_size)),
        
        xaxis=dict(
            showline=True,
            linewidth=1,
            linecolor='black',
            mirror=False
        ),
        yaxis=dict(
            showline=True,
            linewidth=1,
            linecolor='black',
            mirror=False
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    # Remove gridlines
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=True)
    fig.update_traces(marker=dict(pattern=dict(shape="/")), selector=dict(name="AutoSeg3D*"))
    fig.update_traces(marker=dict(pattern=dict(shape="/")), selector=dict(name="nnUnet*"))
    fig.update_traces(marker=dict(pattern=dict(shape="/")), selector=dict(name="VISTA3D Auto*"))
    fig.show()


          Model     Group  Dice Score  Upper CI  Lower CI
0  CT FM (Ours)  vista_v2    0.855843  0.069134  0.069134
1  Random Init.  vista_v2    0.859624  0.064372  0.064372
2        SuPREM  vista_v2    0.829525  0.070463  0.070463
3670 3670
P-value for CT FM (Ours) vs Random Init.: 0.038561849821337896
3670 3670
P-value for CT FM (Ours) vs SuPREM: 3.324421425955082e-67
3670 3670
P-value for Random Init. vs SuPREM: 2.1648298163750495e-61


          Model     Group  Dice Score  Upper CI  Lower CI
0  CT FM (Ours)  vista_v2    0.901312  0.011456  0.011456
1  Random Init.  vista_v2    0.897133  0.013338  0.013338
2        SuPREM  vista_v2    0.884196  0.014770  0.014770
3430 3430
P-value for CT FM (Ours) vs Random Init.: 0.0004896416345427326
3430 3430
P-value for CT FM (Ours) vs SuPREM: 1.2689047213826765e-18
3430 3430
P-value for Random Init. vs SuPREM: 2.491775402834954e-12


          Model     Group  Dice Score  Upper CI  Lower CI
0  CT FM (Ours)  vista_v2    0.893326  0.019472  0.019472
1  Random Init.  vista_v2    0.890723  0.019907  0.019907
2        SuPREM  vista_v2    0.884081  0.017535  0.017535
2777 2777
P-value for CT FM (Ours) vs Random Init.: 8.984951540199968e-05
2777 2777
P-value for CT FM (Ours) vs SuPREM: 1.3226307362365186e-15
2777 2777
P-value for Random Init. vs SuPREM: 9.164593714846009e-10


          Model     Group  Dice Score  Upper CI  Lower CI
0  CT FM (Ours)  vista_v2    0.941215  0.011412  0.011412
1  Random Init.  vista_v2    0.938947  0.012263  0.012263
2        SuPREM  vista_v2    0.915667  0.015504  0.015504
3231 3231
P-value for CT FM (Ours) vs Random Init.: 0.003015800308568221
3231 3231
P-value for CT FM (Ours) vs SuPREM: 1.0686173187232562e-62
3231 3231
P-value for Random Init. vs SuPREM: 2.0811169449931986e-53


          Model     Group  Dice Score  Upper CI  Lower CI
0  CT FM (Ours)  vista_v2    0.899250  0.006858  0.006858
1  Random Init.  vista_v2    0.893748  0.008256  0.008256
2        SuPREM  vista_v2    0.840633  0.020978  0.020978
4461 4461
P-value for CT FM (Ours) vs Random Init.: 5.555299271336475e-07
4461 4461
P-value for CT FM (Ours) vs SuPREM: 1.4936588645051302e-151
4461 4461
P-value for Random Init. vs SuPREM: 7.300220074628607e-121


Empty DataFrame
Columns: [Model, Group, Dice Score, Upper CI, Lower CI]
Index: []


In [30]:
df

Unnamed: 0,Model,Group,Organ,Dice Score,Overall Dice,Macro Dice,Upper CI,Lower CI
0,CT FM (Ours),vista_v2,spleen,0.975054,66.484524,0.898142,0.002286,0.002286
85,CT FM (Ours),vista_v2,autochthon_left,0.963524,66.484524,0.898142,0.002286,0.002286
84,CT FM (Ours),vista_v2,gluteus_minimus_right,0.935229,66.484524,0.898142,0.002286,0.002286
83,CT FM (Ours),vista_v2,gluteus_minimus_left,0.932531,66.484524,0.898142,0.002286,0.002286
82,CT FM (Ours),vista_v2,gluteus_medius_right,0.958140,66.484524,0.898142,0.002286,0.002286
...,...,...,...,...,...,...,...,...
265,SuPREM,vista_v2,vertebrae_T12,0.907227,64.711894,0.869470,0.002659,0.002659
277,SuPREM,vista_v2,vertebrae_C7,0.920082,64.711894,0.869470,0.002659,0.002659
351,AutoSeg3D*,vista_v2,,,,0.882000,0.000000,0.000000
352,nnUnet*,vista_v2,,,,0.906000,0.000000,0.000000


In [68]:
mean_dice

Unnamed: 0,Model,Group,Dice Score,Upper CI,Lower CI


In [40]:
sub_df

Unnamed: 0,Model,Group,Organ,Dice Score,Overall Dice,Macro Dice,Upper CI,Lower CI


In [38]:
mean_dice


Unnamed: 0,Model,Group,level_2,Dice Score
0,CT FM (Ours),vista_v2,mean,0.855843
1,CT FM (Ours),vista_v2,ci_lower,0.791719
2,CT FM (Ours),vista_v2,ci_upper,0.919967
3,Random Init.,vista_v2,mean,0.859624
4,Random Init.,vista_v2,ci_lower,0.799917
5,Random Init.,vista_v2,ci_upper,0.91933
6,SuPREM,vista_v2,mean,0.829525
7,SuPREM,vista_v2,ci_lower,0.764169
8,SuPREM,vista_v2,ci_upper,0.894881


In [44]:
import plotly.express as px

# Prepare data for plotting
plot_data = []
for result in results:
    model_name = result["model"]
    group_name = result["group"]
    for organ, scores in result["all_scores"].items():
        for score in scores:
            plot_data.append({"Model": model_name, "Organ": organ, "Score": score, "Group": group})

# Convert to DataFrame
plot_df = pd.DataFrame(plot_data)

# Plot box plot for each organ using plotly express
fig_per_organ = px.box(plot_df, x="Score", y="Organ", color="Model", title='Dice Scores per Organ for Different Models', color_discrete_sequence=color_list, points="outliers")
fig_per_organ.update_layout(
    xaxis_title='Label',
    yaxis_title='Dice Score',
    legend_title='Model',
    xaxis_tickangle=-90,
    plot_bgcolor='white',
    paper_bgcolor='white',
    font=dict(color='black'),
    height=1600,  # Increase the height of the plot
    width=1600,   # Increase the width of the plot

)
fig_per_organ.update_traces(marker=dict(line=dict(color='black', width=0.5)))

fig_per_organ.show()




In [27]:
plot_df

Unnamed: 0,Model,Organ,Score,Group
0,Random Init.,spleen,,few-shot_5
1,Random Init.,spleen,0.690636,few-shot_5
2,Random Init.,spleen,0.941055,few-shot_5
3,Random Init.,spleen,0.859484,few-shot_5
4,Random Init.,spleen,,few-shot_5
...,...,...,...,...
84586,SuPREM,costal_cartilages,0.000000,few-shot_5
84587,SuPREM,costal_cartilages,0.000000,few-shot_5
84588,SuPREM,costal_cartilages,0.000000,few-shot_5
84589,SuPREM,costal_cartilages,0.000000,few-shot_5


In [45]:
import matplotlib.pyplot as plt

def plot_3d_image(ret):
    # Plot axial slice
    ensure_list = lambda x: x if isinstance(x, list) else [x]
    ret = ensure_list(ret)
    
    plt.figure(figsize=(10, 10))
    for i in range(len(ret)):
        plt.subplot(len(ret), 3, i * 3 + 1)
        plt.imshow(ret[i][:, ret[i].shape[1] // 2, :, :].permute(1, 2, 0), cmap="gray")
        plt.title("Axial")
        plt.axis("off")

    for i in range(len(ret)):
        plt.subplot(len(ret), 3, i * 3 + 2)
        plt.imshow(ret[i][:, :, ret[i].shape[2] // 2, :].permute(1, 2, 0), cmap="gray")
        plt.title("Coronal")
        plt.axis("off")
        
    for i in range(len(ret)):
        plt.subplot(len(ret), 3, i * 3 + 3)
        plt.imshow(ret[i][:, :, :, ret[i].shape[3] // 2].permute(1, 2, 0), cmap="gray")
        plt.title("Sagittal")
        plt.axis("off")

    plt.tight_layout()
    plt.show()
                

In [49]:
from monai.visualize import blend_images


for idx in range(0, 5):
    for result in results:
        print(result["group"], result["model"], result["overall_dice"])

        sample = result["image_samples"][idx]
        image = sample["image"].squeeze(0).cpu()
        label = sample["label"].squeeze(0).cpu()
        pred = sample["pred"].squeeze(0).cpu()
        sid = sample["sid"]

        out_channels = )
        
        label_max = out_channels - 1

        label = label / label_max
        pred = pred / label_max

        image = (image - image.min()) / (image.max() - image.min())
        
        ret_pred = blend_images(image=image, label=pred, alpha=0.3, cmap="hsv", rescale_arrays=False)
        ret_label = blend_images(image=image, label=label, alpha=0.3, cmap="hsv", rescale_arrays=False)
        plot_3d_image([ret_pred, ret_label])
        



vista_v2 CT FM (Ours) 66.48452369628414


## Significance Testing

In [53]:
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu, ttest_rel, ttest_ind, wilcoxon
import numpy as np

def flatten_scores(scores):
    flattened = []
    for organ_scores in scores.values():
        flattened.extend(organ_scores)
    return flattened

for entry in results:
    flattened_scores = flatten_scores(entry['all_scores'])
    for entry_to_compare in results:
        if entry == entry_to_compare:
            continue
        if entry["group"] != entry_to_compare["group"]:
            continue

        flattened_compare_scores = flatten_scores(entry_to_compare['all_scores'])

        if flattened_scores and flattened_compare_scores:
            u_stat, p_value = wilcoxon(flattened_scores, flattened_compare_scores, alternative='two-sided', nan_policy='omit')
            print(f"Wilcoxon test between {entry['model']} and {entry_to_compare['model']} and {entry['group']}: U-statistic = {u_stat}, p-value = {p_value}")
        else:
            print(f"No valid scores to compare between {entry['model']} and {entry_to_compare['model']}")


29016
Wilcoxon test between Random Init. and CT FM (Ours) and few-shot_20: U-statistic = 3607468.5, p-value = 0.0
Wilcoxon test between Random Init. and SuPREM and few-shot_20: U-statistic = 13248355.0, p-value = 7.958366482676549e-155
29016
Wilcoxon test between SuPREM and Random Init. and few-shot_50: U-statistic = 9234273.5, p-value = 0.0
Wilcoxon test between SuPREM and CT FM (Ours) and few-shot_50: U-statistic = 5687688.5, p-value = 0.0
29016
Wilcoxon test between Random Init. and SuPREM and few-shot_50: U-statistic = 9234273.5, p-value = 0.0
Wilcoxon test between Random Init. and CT FM (Ours) and few-shot_50: U-statistic = 53081903.0, p-value = 7.344650660104694e-222
29016
Wilcoxon test between CT FM (Ours) and SuPREM and few-shot_100: U-statistic = 42148726.5, p-value = 0.0
Wilcoxon test between CT FM (Ours) and Random Init. and few-shot_100: U-statistic = 72598075.0, p-value = 9.359182450363814e-05
29016
Wilcoxon test between CT FM (Ours) and Random Init. and few-shot_20: U-sta