In [1]:
from pathlib import Path

import pandas as pd

In [2]:
def merge_epochs(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse all rows that share an ``epoch`` value into a single row.

    For every column other than ``epoch``, the first non-null value found
    within that epoch's rows is kept. If an epoch has no non-null value for a
    column, the merged cell is missing (NaN/None).

    Args:
        df: Frame with an ``epoch`` column. Presumably a metrics log in which
            different rows of one epoch fill different metric columns —
            confirm against the CSV producer.

    Returns:
        A new frame with one row per distinct epoch, sorted by epoch, with
        ``epoch`` restored as a regular column and a fresh 0..n-1 index.

    Raises:
        KeyError: If ``df`` has no ``epoch`` column.
    """
    # GroupBy.first() skips nulls by default, so it yields exactly the
    # "first non-null value per column" without a per-group Python lambda
    # (the previous implementation also ran dropna() twice per column).
    return df.groupby("epoch").first().reset_index()

In [3]:
def load_metrics_from_folder(folder_path: Path) -> dict[str, pd.DataFrame]:
    """Load every ``*.csv`` directly inside a folder as a merged metrics frame.

    Each CSV is read with ``pd.read_csv`` and passed through ``merge_epochs``.
    Sub-folders are not searched.

    Args:
        folder_path: Directory containing the CSV files. A plain string path
            is accepted as well.

    Returns:
        Dict mapping each file's stem (file name without ``.csv``) to its
        merged DataFrame. Keys are inserted in sorted file order so the result
        is reproducible across machines. The dict is empty when no CSV files
        are found.
    """
    folder_path = Path(folder_path)

    # Path.glob yields entries in filesystem-dependent order; sort so the
    # dictionary's key order does not change between runs or platforms.
    csv_files = sorted(folder_path.glob("*.csv"))

    return {file.stem: merge_epochs(pd.read_csv(file)) for file in csv_files}

# Example usage: load every metrics CSV found directly under ./data/results.
folder = Path("./data/results")

# Maps each CSV's file stem to its per-epoch merged metrics DataFrame.
results_dict = load_metrics_from_folder(folder)

In [4]:
# Show which runs were loaded (one key per CSV file stem).
results_dict.keys()

dict_keys(['panda_gray', 'rings_color', 'panda_color', 'cocahis_color', 'cocahis_color_local_dimip', 'cocahis_color_local_ms', 'rings_gray', 'cocahis_gray'])

In [None]:
# Index label of the row with the highest validation Dice for this run.
best_epoch = results_dict['panda_gray']['val_dice'].idxmax()

# idxmax() returns an index *label*, so look the row up with label-based .loc
# rather than positional .iloc (the two only coincide for a default RangeIndex).
display(results_dict['panda_gray'].loc[best_epoch])

results_dict['panda_gray']

In [9]:
# Index label of the row with the highest validation Dice for this run.
best_epoch = results_dict['cocahis_color_local_ms']['val_dice'].idxmax()

# idxmax() returns an index *label*, so look the row up with label-based .loc
# rather than positional .iloc (the two only coincide for a default RangeIndex).
display(results_dict['cocahis_color_local_ms'].loc[best_epoch])

results_dict['cocahis_color_local_ms']

epoch          48.000000
step          293.000000
train_loss      0.401287
val_dice        0.784816
val_iou         0.663430
val_loss        0.395101
Name: 48, dtype: float64

Unnamed: 0,epoch,step,train_loss,val_dice,val_iou,val_loss
0,0,5,0.602331,0.306176,0.189185,0.643621
1,1,11,0.570560,0.307143,0.189852,0.597301
2,2,17,0.562922,0.307154,0.189864,0.582526
3,3,23,0.556548,0.306590,0.189474,0.575406
4,4,29,0.550968,0.306556,0.189448,0.571554
...,...,...,...,...,...,...
66,66,401,0.396204,0.672792,0.544578,0.420592
67,67,407,0.398184,0.675553,0.547729,0.421394
68,68,413,0.397511,0.669195,0.540541,0.426684
69,69,419,0.395318,0.671922,0.543379,0.424580


In [11]:
# Index label of the row with the highest validation Dice for this run.
# NOTE(review): the recorded output shows val_dice = 0.39048 on every displayed
# row, so idxmax() simply returns the first row here — verify how val_dice is
# logged before trusting this "best epoch".
best_epoch = results_dict['panda_gray']['val_dice'].idxmax()

# idxmax() returns an index *label*, so look the row up with label-based .loc
# rather than positional .iloc (the two only coincide for a default RangeIndex).
display(results_dict['panda_gray'].loc[best_epoch])

results_dict['panda_gray']

epoch                 0.000000
step                832.000000
train_epoch_time    661.948364
train_loss            1.125687
val_dice              0.390480
val_epoch_time        4.705945
val_iou               0.129954
val_loss              1.251489
Name: 0, dtype: float64

Unnamed: 0,epoch,step,train_epoch_time,train_loss,val_dice,val_epoch_time,val_iou,val_loss
0,0,832,661.948364,1.125687,0.39048,4.705945,0.129954,1.251489
1,1,1665,660.518921,0.923002,0.39048,4.645448,0.076400,2.901461
2,2,2498,661.834961,0.884415,0.39048,4.617863,0.176912,0.887022
3,3,3331,661.280273,0.844891,0.39048,4.645880,0.185459,0.860156
4,4,4164,664.402283,0.799319,0.39048,4.639325,0.069459,4.678411
...,...,...,...,...,...,...,...,...
71,71,59975,664.265564,0.648409,0.39048,4.748538,0.313245,0.617653
72,72,60808,666.048401,0.648319,0.39048,4.733882,0.313953,0.617308
73,73,61641,664.677368,0.648239,0.39048,4.740481,0.313451,0.618073
74,74,62474,663.068176,0.648474,0.39048,4.731079,0.313059,0.617967


In [10]:
# Index label of the row with the highest validation Dice for this run.
# NOTE(review): the recorded output shows val_dice = 0.39048 on every displayed
# row, so idxmax() simply returns the first row here — verify how val_dice is
# logged before trusting this "best epoch".
best_epoch = results_dict['panda_color']['val_dice'].idxmax()

# idxmax() returns an index *label*, so look the row up with label-based .loc
# rather than positional .iloc (the two only coincide for a default RangeIndex).
display(results_dict['panda_color'].loc[best_epoch])

results_dict['panda_color']

epoch                  0.000000
step                 832.000000
train_epoch_time    1010.515930
train_loss             0.955142
val_dice               0.390480
val_epoch_time        10.627349
val_iou                0.086345
val_loss               5.726986
Name: 0, dtype: float64

Unnamed: 0,epoch,step,train_epoch_time,train_loss,val_dice,val_epoch_time,val_iou,val_loss
0,0,832,1010.51593,0.955142,0.39048,10.627349,0.086345,5.726986
1,1,1665,1004.237122,0.762887,0.39048,9.934948,0.214401,0.846521
2,2,2498,1013.854065,0.712067,0.39048,10.020712,0.263231,0.695357
3,3,3331,1022.203735,0.724197,0.39048,10.347104,0.084119,3.824173
4,4,4164,1017.329224,0.667566,0.39048,10.352876,0.279134,0.674441
5,5,4997,1016.007446,0.654681,0.39048,10.350669,0.221952,0.893198
6,6,5830,1017.957581,0.647736,0.39048,10.358093,0.228043,0.852222
7,7,6663,1015.831299,0.631182,0.39048,10.344902,0.339127,0.588202
8,8,7496,1023.972961,0.629101,0.39048,9.995182,0.297191,0.654823
9,9,8329,1018.764465,0.626479,0.39048,10.362598,0.328222,0.604228


In [6]:
import torch
from monai.transforms import Compose, LoadImage, EnsureChannelFirst, ToTensor
from torchvision.io import read_image
import matplotlib.pyplot as plt
from pathlib import Path

from src.ml_utils.machine_learning import UNetLightning, init_unet_model, normalize_img_for_plot
from src.ml_utils.preprocessing import HENormalization, ToGrayscale
import torchstain

@torch.no_grad()
def run_model_on_image(
    image_path: Path,
    ckpt_path: Path,
    color_mode: str = "color",
    classes: int = 1,
    normalizer_image_path: Path | None = None,
    label_path: Path | None = None,
):
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model (instantiate and load state)
    model = UNetLightning(val_loader=None, original_files=[], color_mode=color_mode, classes=classes)
    model = model.load_from_checkpoint(ckpt_path, val_loader=None, original_files=[], color_mode=color_mode, classes=classes)
    model.eval().to(device)

    # Prepare transforms
    transforms = [LoadImage(image_only=True), EnsureChannelFirst(), ToTensor()]

    if color_mode == "color":
        normalizer = torchstain.normalizers.ReinhardNormalizer(method='modified', backend='torch')
        normalizer.fit(read_image(normalizer_image_path))
        transforms.insert(2, HENormalization(keys=["img"], normalizer=normalizer, method="reinhard"))
    else:
        transforms.insert(2, ToGrayscale(keys=["img"]))

    transform = Compose(transforms)

    # Apply transform
    img = transform({"img": image_path})["img"]
    img_tensor = img.unsqueeze(0).to(device)

    # Inference
    output = model(img_tensor)
    if classes == 1:
        pred = torch.sigmoid(output) > 0.5
    else:
        pred = torch.argmax(torch.softmax(output, dim=1), dim=1, keepdim=True)

    pred_np = pred.squeeze().cpu().numpy()
    img_np = normalize_img_for_plot(img.squeeze().cpu())

    # Optional GT
    gt_np = None
    if label_path is not None:
        gt = LoadImage(image_only=True)(label_path)
        gt = EnsureChannelFirst()(gt)
        if classes > 1:
            gt = ToTensor()(gt.long())
        else:
            gt = ToTensor()(gt.float())
        gt_np = gt.squeeze().numpy()

    # Visualization
    fig, axs = plt.subplots(1, 3 if gt_np is not None else 2, figsize=(12, 4))
    axs[0].imshow(img_np, cmap="gray" if img_np.ndim == 2 else None)
    axs[0].set_title("Input")
    axs[0].axis("off")

    if gt_np is not None:
        axs[1].imshow(gt_np, cmap="gray")
        axs[1].set_title("Ground Truth")
        axs[1].axis("off")
        axs[2].imshow(pred_np, cmap="gray")
        axs[2].set_title("Prediction")
        axs[2].axis("off")
    else:
        axs[1].imshow(pred_np, cmap="gray")
        axs[1].set_title("Prediction")
        axs[1].axis("off")

    plt.tight_layout()
    plt.show()