In [1]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import SimpleITK as sitk
import torch
from monai.metrics import DiceMetric, SurfaceDistanceMetric
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureChannelFirst,
    EnsureType,
    Lambda,
    LoadImage,
    Orientation,
    ToDevice,
    Transpose,
    LabelFilterd,
    MapLabelValued
    
    
)

from lighter.utils.dynamic_imports import import_module_from_path
from pathlib import Path
from tqdm import tqdm
from totalsegmentator.map_to_binary import class_map

import_module_from_path("project", "/home/suraj/Repositories/lighter-ct-fm")
from project.data import get_ts_class_indices, get_ts_class_labels

In [2]:
# Label metadata and I/O locations for the TotalSegmentator evaluation.
# NOTE(review): absolute, machine-specific paths — consider deriving them from a
# single configurable data root so the notebook runs elsewhere.
label_map = class_map["total"]  # presumably label index -> class name for the "total" task; not used in the cells shown

_DATA_ROOT = Path("/mnt/data1")
pred_dir = _DATA_ROOT / "CT_FM" / "evaluations" / "totalseg" / "predictions"  # one sub-directory per "<model>_<group>"
dataset_path = _DATA_ROOT / "TotalSegmentator" / "v2" / "processed"  # ground truth at <sid>/label.nii.gz

In [3]:
# Evaluate every model's predictions against the TotalSegmentator ground truth.
# Each directory under `pred_dir` is named "<model>_<group_a>_<group_b>", where the last
# two underscore-separated parts form the group, e.g. "ct_fm_merlin_v2" -> model "ct_fm",
# group "merlin_v2". Produces `results`: one dict per model with per-class and overall Dice.
results = []

# Sorted so the evaluation (and therefore the plot/legend) order is reproducible;
# stray files in pred_dir are skipped because only directories hold predictions.
model_dirs = sorted(p for p in pred_dir.glob("*") if p.is_dir())

for model_dir in model_dirs:
    name_parts = model_dir.name.split("_")
    group = "_".join(name_parts[-2:])
    model_name = "_".join(name_parts[:-2])

    print(f"Evaluating... Group: {group}, Model: {model_name}\n")
    class_indices = get_ts_class_indices(group=group)
    class_labels = get_ts_class_labels(class_indices, group=group)
    out_channels = len(class_indices)

    # Both prediction and target are single-channel label maps (not one-hot), so the
    # number of classes has to be passed explicitly. reduction="none" keeps one score
    # per class for each subject.
    dice = DiceMetric(include_background=True, num_classes=out_channels, reduction="none")

    # Shared loading pipeline for predictions and targets. The same orientation is applied
    # to both so their voxel grids line up; SuPreM predictions are compared in RAS and all
    # other models in SPL.
    base_transforms = Compose([
        LoadImage(),
        ToDevice(device="cuda"),
        EnsureChannelFirst(),
        EnsureType(data_type="tensor", dtype="int"),
        Orientation(axcodes="SPL" if "suprem" not in model_name else "RAS"),
    ])

    # Ground truth only: keep this group's classes and remap them to 0..out_channels-1 so
    # they match the prediction channels. The dict wrapping is needed because
    # LabelFilterd / MapLabelValued are dictionary transforms.
    mapping_transforms = Compose([
        Lambda(lambda x: {"label": x}),
        LabelFilterd(keys="label", applied_labels=class_indices),
        MapLabelValued(keys="label", orig_labels=class_indices, target_labels=list(range(out_channels))),
        Lambda(lambda x: x["label"]),
    ])

    target_transforms = Compose([base_transforms, mapping_transforms])

    # NOTE(review): predictions get no remapping, so they are assumed to already use
    # labels 0..out_channels-1 — TODO confirm.
    pred_paths = sorted(model_dir.glob("*"))

    print("Calculating Dice Scores... \n")
    dice_dict = {class_label: [] for class_label in class_labels}
    for pred_path in tqdm(pred_paths, desc=model_dir.name):
        # Path.stem strips only the last suffix ("s0001.nii.gz" -> "s0001.nii"), which would
        # point at a non-existent subject folder, so the compound ".nii.gz" is stripped fully.
        if pred_path.name.endswith(".nii.gz"):
            sid = pred_path.name[: -len(".nii.gz")]
        else:
            sid = pred_path.stem
        target = target_transforms(dataset_path / sid / "label.nii.gz").unsqueeze(0)
        pred = base_transforms(pred_path).unsqueeze(0)
        # Batch of one -> result is (1, num_classes). Index the batch rather than squeeze()
        # so a single-class result still yields a list, not a bare float.
        scores = dice(pred, target)[0].tolist()
        for class_label, score in zip(class_labels, scores):
            dice_dict[class_label].append(score)

    # Background is excluded from both the per-class and the overall score.
    dice_dict.pop("background", None)
    # nanmean: DiceMetric (ignore_empty=True by default) yields NaN for classes that are
    # absent from a subject's ground truth; those subjects are skipped for that class.
    dice_dict = {k: np.nanmean(v) for k, v in dice_dict.items()}

    overall_dice = np.nanmean(list(dice_dict.values()))
    results.append({"group": group, "model": model_name, "dice_scores": dice_dict, "overall_dice": overall_dice})



Evaluating... Group: merlin_v2, Model: baseline

Calculating Dice Scores... 

Evaluating... Group: merlin_v2, Model: ct_fm

Calculating Dice Scores... 



In [20]:
# Flatten `results` into long format: one row per (model, group, organ), with the model's
# overall Dice repeated on each of its rows.
rows = [
    {
        'Model': entry['model'],
        'Group': entry['group'],
        'Organ': organ,
        'Dice Score': dice_score,
        'Overall Dice': entry['overall_dice'],
    }
    for entry in results
    for organ, dice_score in entry['dice_scores'].items()
]
df = pd.DataFrame(rows)

# Per-class Dice: one cluster of bars per organ, one bar per model.
fig_per_class = px.bar(
    df,
    x='Organ',
    y='Dice Score',
    color='Model',
    title='Per Class Dice Score Comparison',
    barmode='group',
)
fig_per_class.show()

# Overall Dice: the value is repeated on every organ row, so de-duplicate before plotting.
overall_df = df[["Model", "Overall Dice", "Group"]].drop_duplicates()
fig = px.bar(
    overall_df,
    x='Group',
    y='Overall Dice',
    color='Model',
    title='Overall Dice Comparison',
    barmode='group',
    height=800,
    width=400,
)
fig.show()