In [1]:
from data_modules.ocelot import OcelotDataModule
from models.dynunet import DynUNetModel
import monai.transforms as T
from lightning.pytorch import seed_everything
import torch
from system import System
from monai.inferers import SlidingWindowInferer
from monai.metrics import (
    DiceMetric,
    HausdorffDistanceMetric,
)
import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from ocelot_util import (
    normalize_crop_coords_batch,
    evaluate_cell_detection_batch, 
    cell_detection_postprocessing_batch, 
    calculate_metrics, 
    load_ground_truth,
)
import random

# Fix the global seed through Lightning's seed_everything so that repeated
# runs of this notebook are comparable.
seed_everything(42)

[rank: 0] Seed set to 42


42

In [2]:
# Sanity check: report whether PyTorch can see a CUDA device from this kernel.
print(torch.cuda.is_available())

True


In [3]:
# Checkpoint to evaluate, given relative to the notebook's working directory.
path = './lightning_logs/ib1kxueb/checkpoints/model-epoch=143-val_loss=0.74.ckpt'
# Label used as the filename prefix of the figures written at the end of the notebook.
name = 'SwinUNETR'
# Restore the System module (and the network it wraps as `.net`) from the checkpoint.
cell_classifier = System.load_from_checkpoint(checkpoint_path=path)


monai.networks.nets.swin_unetr SwinUNETR.__init__:img_size: Argument `img_size` has been deprecated since version 1.3. It will be removed in version 1.5. The img_size argument is not required anymore and checks on the input size are run during forward().
Attribute 'net' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['net'])`.


In [4]:
# batch_size must stay 1: evaluate_cell_detection_with_coords (defined further
# down) only reads element 0 of each per-batch list.
data_module = OcelotDataModule(batch_size=1, num_workers=1)
data_module.prepare_data()
data_module.setup()

# Only the test split is evaluated in this notebook.
test_loader = data_module.test_dataloader()

Train: 400 | Val: 137 | Test: 126


In [5]:
# One output folder per figure category; exist_ok=True keeps the cell re-runnable.
for figure_dir in (
    "ocelot_cell/best",
    "ocelot_cell/worst",
    "ocelot_cell/average",
    "ocelot_cell/random",
):
    os.makedirs(figure_dir, exist_ok=True)


In [6]:
# Softmax over dim=1 of the batched network output. The post-processing below
# reads channels 0/1/2 of every batch item as background / background-cell /
# tumor-cell, so dim 1 is the class axis.
cell_post_transform = T.Activations(softmax = True,  dim=1)

In [7]:
from skimage.feature import peak_local_max
import numpy as np
import cv2
from scipy.spatial import cKDTree

def evaluate_cell_detection_with_coords(pred_cells, pred_classes, gt_cells, gt_classes, max_distance=15):
    """Greedily match predicted cells to ground-truth cells of the same class.

    Only the first element of every argument is evaluated, i.e. the function
    expects the per-batch lists produced with a batch size of 1.

    Matching is done independently for class 1 and class 2. Predictions are
    visited in the order given; each one claims the closest still-unclaimed
    ground-truth cell within ``max_distance`` (Euclidean), so a ground-truth
    cell is matched at most once.

    Args:
        pred_cells: sequence whose first item holds the predicted 2-D coordinates, shape (N, 2).
        pred_classes: sequence whose first item holds the N predicted class labels.
        gt_cells: sequence whose first item holds the ground-truth coordinates, shape (M, 2).
            An empty array (any shape) means "no annotated cells".
        gt_classes: sequence whose first item holds the M ground-truth class labels.
        max_distance: maximum distance for a prediction to count as a hit.

    Returns:
        Tuple of six numpy arrays:
        ``(matched_gt_cells, matched_gt_classes, unmatched_gt_cells,
        unmatched_gt_classes, unmatched_pred_cells, unmatched_pred_classes)``.
        Empty results are returned as empty arrays.
    """
    matched_gt_cells, matched_gt_classes = [], []
    unmatched_gt_cells, unmatched_gt_classes = [], []
    unmatched_pred_cells, unmatched_pred_classes = [], []

    # Coerce to arrays so the boolean masks below also work for plain lists.
    pred_cells = np.asarray(pred_cells[0])
    pred_classes = np.asarray(pred_classes[0])
    gt_cells = np.asarray(gt_cells[0])
    gt_classes = np.asarray(gt_classes[0])

    for class_val in [1, 2]:
        class_gt_cells = gt_cells[gt_classes == class_val]
        class_pred_cells = pred_cells[pred_classes == class_val]

        if len(class_gt_cells) == 0:
            # Nothing to match against: every prediction of this class is a
            # false positive. Building a cKDTree here would raise for 1-D
            # empty input, so skip it.
            unmatched_pred_cells.extend(class_pred_cells)
            unmatched_pred_classes.extend([class_val] * len(class_pred_cells))
            continue

        gt_tree = cKDTree(class_gt_cells)
        matched_gt_indices = set()

        for pred in class_pred_cells:
            neighbors = gt_tree.query_ball_point(pred, max_distance)
            available = [i for i in neighbors if i not in matched_gt_indices]

            if available:
                closest_idx = min(available, key=lambda i: np.linalg.norm(pred - class_gt_cells[i]))
                matched_gt_indices.add(closest_idx)
                matched_gt_cells.append(class_gt_cells[closest_idx])
                matched_gt_classes.append(class_val)
            else:
                unmatched_pred_cells.append(pred)
                unmatched_pred_classes.append(class_val)

        # Ground-truth cells that no prediction claimed are false negatives.
        for i, gt_cell in enumerate(class_gt_cells):
            if i not in matched_gt_indices:
                unmatched_gt_cells.append(gt_cell)
                unmatched_gt_classes.append(class_val)

    # Convert to numpy arrays for plotting
    return (
        np.array(matched_gt_cells),
        np.array(matched_gt_classes),
        np.array(unmatched_gt_cells),
        np.array(unmatched_gt_classes),
        np.array(unmatched_pred_cells),
        np.array(unmatched_pred_classes),
    )
    
def cell_detection_postprocessing(y_tc, y_bc, y_bg, min_distance=3, sigma=3):
    """Turn per-pixel class probabilities of one image into cell detections.

    The foreground map ``1 - y_bg`` is Gaussian-blurred and its local maxima
    are taken as cell candidates. A candidate is kept when either cell class
    beats the background probability at that pixel, and it is labelled by
    whichever cell class is larger there.

    Args:
        y_tc: 2-D map of tumor-cell probabilities (torch tensor or numpy array).
        y_bc: 2-D map of background-cell probabilities, same shape.
        y_bg: 2-D map of background probabilities, same shape.
        min_distance: minimum separation between peaks, passed to ``peak_local_max``.
        sigma: standard deviation of the Gaussian blur applied before peak detection.

    Returns:
        ``(cells, classes, confidences)``, all sorted by descending confidence:
        ``cells`` is an (N, 2) integer array of peak coordinates in the order
        returned by ``peak_local_max``; ``classes`` holds 1 (background cell)
        or 2 (tumor cell); ``confidences`` is the larger of the two cell-class
        probabilities at the peak. With no detections ``cells`` has shape (0, 2).
    """
    y_tc, y_bc, y_bg = (
        prob.cpu().detach().numpy() if isinstance(prob, torch.Tensor) else prob
        for prob in (y_tc, y_bc, y_bg)
    )

    # Smooth the foreground probability so each cell yields a single peak.
    foreground = cv2.GaussianBlur(1 - y_bg, (0, 0), sigmaX=sigma)
    cell_candidates = peak_local_max(foreground, min_distance=min_distance, exclude_border=0, threshold_abs=0.0)

    # Look up the (unblurred) probabilities at every candidate in one go.
    rows, cols = cell_candidates[:, 0], cell_candidates[:, 1]
    p_tc = y_tc[rows, cols]
    p_bc = y_bc[rows, cols]
    p_bg = y_bg[rows, cols]

    # Keep candidates where some cell class is more likely than background.
    keep = (p_tc > p_bg) | (p_bc > p_bg)
    valid_cells = cell_candidates[keep]
    p_tc, p_bc = p_tc[keep], p_bc[keep]

    # 1 = background cell, 2 = tumor cell (tumor wins ties, as before).
    valid_classes = np.where(p_bc > p_tc, 1, 2)
    confidence_scores = np.maximum(p_tc, p_bc)

    # Sort by confidence score (descending order)
    sorted_indices = np.argsort(-confidence_scores)
    return valid_cells[sorted_indices], valid_classes[sorted_indices], confidence_scores[sorted_indices]

def cell_detection_postprocessing_batch(y_c_batch, min_distance=3):
    """Run ``cell_detection_postprocessing`` on every item of a batch.

    Note: this definition rebinds the name imported from ``ocelot_util`` at
    the top of the notebook.

    Args:
        y_c_batch: iterable of per-image probability stacks; index 0 is read
            as background, 1 as background cell and 2 as tumor cell.
        min_distance: forwarded to ``cell_detection_postprocessing``.

    Returns:
        Three lists (cells, classes, confidences) with one entry per batch item.
    """
    per_item_results = []
    for probabilities in y_c_batch:
        background, background_cell, tumor_cell = probabilities[0], probabilities[1], probabilities[2]
        per_item_results.append(
            cell_detection_postprocessing(tumor_cell, background_cell, background, min_distance=min_distance)
        )

    # Transpose "list of triples" into "triple of lists".
    cells_list = [result[0] for result in per_item_results]
    classes_list = [result[1] for result in per_item_results]
    confidences_list = [result[2] for result in per_item_results]
    return cells_list, classes_list, confidences_list

In [8]:
# Running extremes of the per-sample mean F1. The start values are chosen so
# the first kept sample replaces both (presumably F1 lies in [0, 1]).
best_f1 = -1
worst_f1 = 2

best_sample = None
worst_sample = None
all_samples = []

# Nothing else in this notebook puts the restored module into inference mode;
# do it here so dropout / normalisation layers behave deterministically.
cell_classifier.eval()

for i, batch in enumerate(tqdm(test_loader, desc="Evaluating")):
    gt_cells_list, gt_classes_list = load_ground_truth(batch)
    meta_batch = batch["meta"]
    cropped_coords = normalize_crop_coords_batch(meta_batch)

    # Evaluation needs no gradients; skipping autograd avoids keeping the
    # graph alive for every test image. Call the module itself rather than
    # .forward() so registered hooks are honoured.
    with torch.no_grad():
        _, cell_pred = cell_classifier.net(batch, cropped_coords, train_mode=False)
        cell_pred = cell_post_transform(cell_pred)

    # Get predicted cells and their classes from the model outputs
    pred_cells_list, pred_classes_list, confidences_list = cell_detection_postprocessing_batch(cell_pred)

    total_tp_bc, total_fp_bc, total_fn_bc, total_tp_tc, total_fp_tc, total_fn_tc = evaluate_cell_detection_batch(
        pred_cells_list, pred_classes_list, gt_cells_list, gt_classes_list,
    )

    precision_bc, recall_bc, f1_bc = calculate_metrics(total_tp_bc, total_fp_bc, total_fn_bc)
    precision_tc, recall_tc, f1_tc = calculate_metrics(total_tp_tc, total_fp_tc, total_fn_tc)

    mean_f1 = (f1_bc + f1_tc) / 2

    # Only keep samples whose ground truth contains both cell classes, so the
    # mean of the two per-class F1 scores is meaningful.
    unique_classes = {c for sublist in gt_classes_list for c in sublist}
    if not (1 in unique_classes and 2 in unique_classes):
        continue

    # Coordinate-level matching is only needed for samples that may be plotted.
    (matched_gt_cells, matched_gt_classes,
     unmatched_gt_cells, unmatched_gt_classes,
     unmatched_pred_cells, unmatched_pred_classes) = evaluate_cell_detection_with_coords(
        pred_cells_list, pred_classes_list, gt_cells_list, gt_classes_list, max_distance=15)

    sample = {
        "index": i,
        "image": batch["img_cell"],
        "mean_f1": mean_f1,
        "gt_classes_list": gt_classes_list,
        "pred_classes_list": pred_classes_list,
        "gt_cells_list": gt_cells_list,
        "pred_cells_list": pred_cells_list,
        "matched_gt_list": [matched_gt_cells],
        "matched_gt_classes_list": [matched_gt_classes],
        "unmatched_gt_list": [unmatched_gt_cells],
        "unmatched_gt_classes_list": [unmatched_gt_classes],
        "wrong_pred_cells_list": [unmatched_pred_cells],
        "wrong_pred_classes_list": [unmatched_pred_classes],
    }
    all_samples.append(sample)

    if mean_f1 > best_f1:
        best_f1 = mean_f1
        best_sample = sample

    if mean_f1 < worst_f1:
        worst_f1 = mean_f1
        worst_sample = sample


Evaluating: 100%|██████████| 126/126 [01:54<00:00,  1.10it/s]


In [9]:
# Fixed positions into `all_samples` (the filtered list, not test-set indices),
# hard-coded instead of drawn with random.sample(all_samples, 3).
# Raises IndexError if fewer than 82 samples were kept.
random_indices = [81, 14, 3]
random_samples = [all_samples[position] for position in random_indices]

In [10]:
import os
import matplotlib.pyplot as plt

def _plot_cell_markers(ax, coords, classes, class_colors, marker):
    """Draw all points of each class in ``class_colors`` with a single plot call."""
    coords = np.asarray(coords).reshape(-1, 2)  # also accepts empty 1-D input
    classes = np.asarray(classes).reshape(-1)
    for cls, color in class_colors.items():
        points = coords[classes == cls]
        if len(points):
            # Coordinates are stored as (row, col); matplotlib wants (x, y).
            ax.plot(points[:, 1], points[:, 0], marker, color=color, markersize=3)


def save_cell_visualization(sample, save_path, title=""):
    """Overlay the matching result of one evaluated sample on its image and save it.

    Marker legend:
        * dot, green / red   - matched ground-truth cell of class 1 / 2
        * dot, blue / orange - missed ground-truth cell of class 1 / 2
        * x, green / red     - unmatched prediction of class 1 / 2

    Args:
        sample: dict built in the evaluation loop. Reads ``"image"`` (first
            batch item, channel-first), ``"mean_f1"`` and the
            ``matched_gt*`` / ``unmatched_gt*`` / ``wrong_pred*`` lists.
        save_path: destination file; its parent directory is created if needed.
        title: optional label prepended to the "Mean F1" figure title.
    """
    # Visualize only the first image in the batch
    image = sample["image"][0].detach().cpu()
    image_np = image.permute(1, 2, 0).numpy()  # [C, H, W] -> [H, W, C]

    # Rescale to [0, 1] for imshow; a constant image would divide by zero.
    lowest, highest = image_np.min(), image_np.max()
    if highest > lowest:
        image_np = (image_np - lowest) / (highest - lowest)
    else:
        image_np = np.zeros_like(image_np)

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image_np)
        ax.axis("off")
        score_text = f"Mean F1: {sample['mean_f1']:.2f}"
        ax.set_title(f"{title} | {score_text}" if title else score_text)

        # True positives (matched ground truth)
        _plot_cell_markers(ax, sample["matched_gt_list"][0], sample["matched_gt_classes_list"][0],
                           {1: 'g', 2: 'r'}, 'o')
        # False negatives (missed ground truth)
        _plot_cell_markers(ax, sample["unmatched_gt_list"][0], sample["unmatched_gt_classes_list"][0],
                           {1: 'b', 2: 'orange'}, 'o')
        # False positives (wrong predictions)
        _plot_cell_markers(ax, sample["wrong_pred_cells_list"][0], sample["wrong_pred_classes_list"][0],
                           {1: 'g', 2: 'r'}, 'x')

        # os.makedirs("") raises, so only create a directory when the path has one.
        output_dir = os.path.dirname(save_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(save_path, dpi=1024, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


In [11]:
# Export the best, worst and hand-picked samples; `name` prefixes every file so
# figures from different models can share the same folders.
save_cell_visualization(best_sample, f"ocelot_cell/best/{name}_sample_best.png", title="Best Sample")
save_cell_visualization(worst_sample, f"ocelot_cell/worst/{name}_sample_worst.png", title="Worst Sample")
for i, sample in enumerate(random_samples):
    save_cell_visualization(sample, f"ocelot_cell/random/{name}_random_sample_{i}.png", title=f"Random Sample {i}")
