In [None]:
%load_ext autoreload
%autoreload 2
from trainer import LitModel
import torch 
from shared_modules.data_module_all import DataModule
from shared_modules.utils import load_config
from monai.transforms import AsDiscrete
from tqdm import tqdm
from monai.metrics import DiceMetric
from shared_modules.torch_metrics import PicaiMetric

In [17]:

# Load the training config and override it for evaluation on the full sample list.
config = load_config("config.yaml")

# Data sources.
config.data.data_dir = "../../../data/"
config.data.json_list = "../../../json_datalists/p158/all_samples.json"

# Device and loader settings.
gpu = 0
config.gpus = [gpu]
config.cache_rate = 1.0
config.num_workers = 16

# Transform settings: three image keys, lesion ("pca") + zone labels, no crop key.
config.transforms.image_keys = ["t2w", "adc", "hbv"]
config.transforms.label_keys = ["pca", "zones"]
config.transforms.crop_key = None
config.transforms.roi_size = [256, 256, 20]
config.transforms.spatial_size = [256, 256, 20]

weights_folder = "../../../gc_algorithms/base_container/models/umamba_mtl/weights"

# One checkpoint per fold (f0.ckpt ... f4.ckpt). Each model is switched to eval
# mode (disables dropout etc.) and then moved to the selected GPU.
models = [
    LitModel.load_from_checkpoint(
        f"{weights_folder}/f{fold_idx}.ckpt",
        config=config,
        map_location=f"cuda:{gpu}",
    ).eval().to(gpu)
    for fold_idx in range(5)
]

In [None]:
# Build the data module from the evaluation config and take its test-split loader.
dm = DataModule(config=config)
dm.setup("test")
dl = dm.test_dataloader()

In [None]:
# Metric accumulators. DiceMetric buffers the per-case scores on every call; they are
# read back later with `dsc_fn.aggregate(...)`.
dsc_fn = DiceMetric(include_background=True, reduction="none")
# TODO(review): never updated in the loop below, so no PI-CAI detection metrics are produced.
picai_metric_fn = PicaiMetric()

# Stateless 0.5 threshold; built once instead of once per fold per batch.
binarize = AsDiscrete(threshold=0.5)

with torch.no_grad():
    for batch in tqdm(dl):
        x = batch["image"].to(gpu)

        # Per-fold outputs. Every tensor keeps its batch and channel dims, (B, C, ...),
        # so batches larger than 1 are handled instead of silently using sample 0 only.
        preds = []
        probs = []
        fold_zones = []

        for model in models:
            logits = model.inferer(x)

            # Channel 1 -> lesion probability via sigmoid; slicing 1:2 keeps the channel dim.
            probs.append(torch.sigmoid(logits[:, 1:2]))
            preds.append(binarize(probs[-1]))
            # Channels 2+ -> zone probabilities via softmax over the zone channels;
            # the first zone channel (presumably background — confirm) is dropped.
            fold_zones.append(torch.softmax(logits[:, 2:], dim=1)[:, 1:])

        # Fold ensemble: majority vote on the binary lesion masks, mean lesion probability,
        # and a 0.5 threshold on the fold-averaged zone probabilities.
        batch["pred"] = (torch.mean(torch.stack(preds), dim=0) > 0.5).float()
        batch["prob"] = torch.mean(torch.stack(probs), dim=0)
        batch["zones_pred"] = (torch.mean(torch.stack(fold_zones), dim=0) > 0.5).float()
        # Predicted gland mask = union of predicted zones; keepdim gives shape (B, 1, ...).
        batch["prostate_pred"] = (
            torch.sum(batch["zones_pred"], dim=1, keepdim=True) > 0.5
        ).float()

        # Need to flip TZ and PZ to match prediction order
        batch["zones"] = torch.stack(
            [batch["zones"][:, 2, ...], batch["zones"][:, 1, ...]], dim=1
        )
        # Ground-truth gland mask = union of the reordered zone channels, shape (B, 1, ...).
        batch["prostate"] = (torch.sum(batch["zones"], dim=1, keepdim=True) > 0.5).float()

        # Channel order for both tensors: [prostate, zone channel 1, zone channel 2].
        all_preds = torch.cat(
            [batch["prostate_pred"].to(gpu), batch["zones_pred"].to(gpu)], dim=1
        )
        all_gt = torch.cat([batch["prostate"].to(gpu), batch["zones"].to(gpu)], dim=1)
        dsc = dsc_fn(y_pred=all_preds, y=all_gt)

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Zone-segmentation Dice summary; only produced when zone labels were loaded.
if "zones" in config.transforms.label_keys:

    # Per-case Dice scores buffered by `dsc_fn` during the inference loop:
    # one row per case, one column per evaluated channel.
    dsc_metrics = dsc_fn.aggregate("none")

    # Columns follow the channel order concatenated in the loop:
    # [prostate, zone channel 1, zone channel 2].
    # NOTE(review): the "PZ"/"TZ" labels assume the flipped ground-truth order is
    # (PZ, TZ) — confirm against the label definition of the "zones" key.
    region_names = ["Prostate", "PZ", "TZ"]
    df = pd.DataFrame({
        name: dsc_metrics[:, idx].cpu().numpy()
        for idx, name in enumerate(region_names)
    })

    # Long format (one row per Region/DSC pair), as expected by seaborn.
    df_long = df.melt(var_name="Region", value_name="DSC")

    # Exactly one colour per region, so the violins and the legend below agree.
    palette = sns.color_palette("muted", n_colors=len(region_names))
    fig, ax = plt.subplots()
    # `Region` is mapped to hue as well as x: passing `palette` without `hue` is
    # deprecated since seaborn 0.13. `dodge=False` keeps each violin centred on its tick.
    sns.violinplot(
        data=df_long, x="Region", y="DSC", hue="Region",
        palette=palette, dodge=False, ax=ax,
    )

    # pandas skips NaN entries in mean/std (DiceMetric may emit NaN for cases with an
    # empty reference — confirm for the MONAI version in use).
    means = df.mean()
    stds = df.std()

    # Custom legend showing mean ± std per region, colour-matched to the violins;
    # it replaces any legend seaborn drew for the hue mapping.
    handles = [
        plt.Line2D([0], [0], color=color, lw=4, label=f"{col}: {mean:.2f} ± {std:.2f}")
        for color, col, mean, std in zip(palette, df.columns, means, stds)
    ]
    ax.legend(handles=handles, title="Mean DSC Scores", loc="lower left")

    ax.set_title("DSC Distribution")

    plt.show()