## Inference on the validation set with the trained model

In [10]:
import os
import os.path as osp
import numpy as np
import torch
from dataclasses import dataclass

from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from medpy.filter.binary import largest_connected_component
from skimage.io import imsave
from torch.utils.data import DataLoader

from dataset import BrainSegmentationDataset as Dataset
from unet import UNet
from utils import dsc, gray2rgb, outline

In [11]:
def data_loader(args):
    """Build the DataLoader that iterates over the validation subset.

    Args:
        args: configuration object; ``images``, ``image_size`` and
            ``batch_size`` are read from it.

    Returns:
        A ``DataLoader`` over the "validation" subset. Random sampling is
        disabled and no shuffle option is passed, and incomplete final
        batches are kept (``drop_last=False``).
    """
    validation_set = Dataset(
        images_dir=args.images,
        subset="validation",
        image_size=args.image_size,
        random_sampling=False,
    )
    # num_workers=0 loads batches in the calling process.
    return DataLoader(
        validation_set,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=0,
    )


def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    """Regroup flat per-slice lists into one volume per patient.

    The number of slices per patient is obtained by counting how often each
    patient index (first element of every ``patient_slice_index`` entry)
    occurs. Consecutive runs of that length are then cut out of the three
    lists, so the lists are assumed to be ordered patient by patient
    (TODO confirm against the dataset implementation).

    Args:
        input_list: per-slice network inputs.
        pred_list: per-slice raw predictions; they are rounded, cast to int
            and reduced to their largest connected component.
        true_list: per-slice ground-truth masks.
        patient_slice_index: sequence whose entries start with a patient index.
        patients: sequence mapping a patient index to the key used in the
            returned dict.

    Returns:
        dict mapping ``patients[i]`` to ``(inputs, prediction, ground_truth)``.
    """
    slices_per_patient = np.bincount(
        [entry[0] for entry in patient_slice_index]
    )

    result = {}
    start = 0
    for patient_idx, count in enumerate(slices_per_patient):
        stop = start + count
        window = slice(start, stop)

        inputs = np.array(input_list[window])
        # Binarise the prediction before keeping only its largest component.
        binary_pred = np.round(np.array(pred_list[window])).astype(int)
        binary_pred = largest_connected_component(binary_pred)
        targets = np.array(true_list[window])

        result[patients[patient_idx]] = (inputs, binary_pred, targets)
        start = stop
    return result


def dsc_distribution(volumes):
    """Compute the Dice score of every patient volume.

    Args:
        volumes: dict mapping a patient key to a sequence whose element 1 is
            the predicted volume and element 2 the ground-truth volume.

    Returns:
        dict mapping each patient key to ``dsc(prediction, truth, lcc=False)``.
    """
    return {
        patient: dsc(entry[1], entry[2], lcc=False)
        for patient, entry in volumes.items()
    }


def plot_dsc(dsc_dist):
    """Render a horizontal bar chart of per-patient Dice scores.

    Bars are sorted by ascending score. Vertical lines mark the mean
    ("tomato") and the median ("forestgreen") of the scores.

    Args:
        dsc_dist: dict mapping a patient key to its Dice score.

    Returns:
        The rendered figure as a ``uint8`` array of shape
        ``(height, width, 4)``.
    """
    ranked = sorted(dsc_dist.items(), key=lambda item: item[1])
    scores = [score for _, score in ranked]
    # Shorten each key by dropping its first and last "_"-separated tokens.
    tick_labels = ["_".join(name.split("_")[1:-1]) for name, _ in ranked]
    bar_positions = np.arange(len(ranked))

    fig = plt.figure(figsize=(12, 8))
    canvas = FigureCanvasAgg(fig)
    axes = fig.gca()

    axes.barh(bar_positions, scores, align="center", color="skyblue")
    axes.set_yticks(bar_positions)
    axes.set_yticklabels(tick_labels)
    axes.set_xticks(np.arange(0.0, 1.0, 0.1))
    axes.set_xlim([0.0, 1.0])
    axes.axvline(np.mean(scores), color="tomato", linewidth=2)
    axes.axvline(np.median(scores), color="forestgreen", linewidth=2)
    axes.set_xlabel("Dice coefficient", fontsize="x-large")
    axes.xaxis.grid(color="silver", alpha=0.5, linestyle="--", linewidth=1)
    fig.tight_layout()

    # Rasterise first, then close the pyplot figure; the canvas still holds
    # the rendered pixels afterwards.
    canvas.draw()
    plt.close()
    buffer, (width, height) = canvas.print_to_buffer()
    return np.frombuffer(buffer, np.uint8).reshape((height, width, 4))


def makedirs(args):
    """Ensure the directory named by ``args.predictions`` exists."""
    predictions_dir = args.predictions
    # exist_ok=True makes repeated calls a no-op instead of an error.
    os.makedirs(predictions_dir, exist_ok=True)


def plot_accuracy(acc_dist):
    """Render a horizontal bar chart of per-patient accuracy.

    Bars are sorted by ascending accuracy and the x axis is limited to
    [0.9, 1.0]. Vertical lines mark the mean ("tomato") and the median
    ("forestgreen") of the values.

    Args:
        acc_dist: dict mapping a patient key to its accuracy.

    Returns:
        The rendered figure as a ``uint8`` array of shape
        ``(height, width, 4)``.
    """
    ranked = sorted(acc_dist.items(), key=lambda item: item[1])
    accuracies = [entry[1] for entry in ranked]
    # Shorten each key by dropping its first and last "_"-separated tokens.
    tick_labels = ["_".join(entry[0].split("_")[1:-1]) for entry in ranked]
    bar_positions = np.arange(len(ranked))

    fig = plt.figure(figsize=(12, 8))
    canvas = FigureCanvasAgg(fig)
    axes = fig.gca()

    axes.barh(bar_positions, accuracies, align="center", color="lightgray")
    axes.set_yticks(bar_positions)
    axes.set_yticklabels(tick_labels)
    axes.set_xticks(np.arange(0.9, 1.0, 0.01))
    axes.set_xlim([0.9, 1.0])
    axes.axvline(np.mean(accuracies), color="tomato", linewidth=2)
    axes.axvline(np.median(accuracies), color="forestgreen", linewidth=2)
    axes.set_xlabel("Accuracy", fontsize="x-large")
    axes.grid(axis="x", linestyle="--", alpha=0.5)
    fig.tight_layout()

    # Rasterise first, then close the pyplot figure; the canvas still holds
    # the rendered pixels afterwards.
    canvas.draw()
    plt.close()
    buffer, (width, height) = canvas.print_to_buffer()
    return np.frombuffer(buffer, np.uint8).reshape((height, width, 4))

def acc_distribution(volumes):
    """Compute the voxel-wise accuracy of every patient volume.

    Accuracy is ``(TP + TN) / (TP + TN + FP + FN)``, counted over all voxels
    whose value is exactly 0 or 1 in both masks. Voxels holding any other
    value in either mask are left out of all four counts.

    Args:
        volumes: dict mapping a patient key to a 3-tuple
            ``(inputs, prediction, ground_truth)``; only the two masks are
            used.

    Returns:
        dict mapping each patient key to its accuracy as a Python float in
        [0, 1]. A volume without any countable voxel scores 0.0.
    """
    acc_dist = {}
    for patient, (_, vol_pred, vol_true) in volumes.items():
        # ravel() gives a flat view where possible instead of a full copy.
        pred_flat = np.asarray(vol_pred).ravel()
        true_flat = np.asarray(vol_true).ravel()

        pred_pos, pred_neg = pred_flat == 1, pred_flat == 0
        true_pos, true_neg = true_flat == 1, true_flat == 0

        tp = int((pred_pos & true_pos).sum())
        tn = int((pred_neg & true_neg).sum())
        fp = int((pred_pos & true_neg).sum())
        fn = int((pred_neg & true_pos).sum())

        # Guard the empty case explicitly rather than adding an epsilon to
        # the denominator: the epsilon skews every score, so a perfect
        # prediction would never reach exactly 1.0.
        total = tp + tn + fp + fn
        acc_dist[patient] = (tp + tn) / total if total else 0.0
    return acc_dist

In [12]:
@dataclass
class Args:
    """Settings for the validation inference run.

    Every attribute carries a type annotation so that ``@dataclass`` treats
    it as a field; without annotations the decorator generates an empty
    ``__init__``/``__repr__`` and values cannot be overridden per instance
    via keyword arguments.
    """

    # Device string used when CUDA is available.
    device: str = "cuda:0"
    # Number of slices per inference batch.
    batch_size: int = 32
    # Path of the trained model weights to load.
    weights: str = "./weights/1-unet.pt"
    # Root folder of the image dataset.
    images: str = "./BrainMRI/kaggle_3m"
    # Image size passed to the dataset.
    image_size: int = 256
    # Folder that receives the per-slice prediction images.
    predictions: str = "./predictions"
    # Output path of the Dice score bar chart.
    dsc_figure: str = "./report/img/validation_dsc_2.png"
    # Output path of the accuracy bar chart.
    accuracy_figure: str = "./report/img/validation_accuracy_2.png"
    # Log folder (not read by the code visible in this notebook section).
    logs: str = "./logs"


args = Args()

In [13]:
# Fail early with a clear error when the dataset is missing. An explicit raise
# is used because `assert` statements are skipped when Python runs with -O.
if not osp.exists(args.images):
    raise FileNotFoundError("Please download the dataset and set the correct path")

makedirs(args)
# makedirs() only covers args.predictions; create the figure folders up front
# so saving the plots below cannot fail on a missing directory.
for figure_path in (args.dsc_figure, args.accuracy_figure):
    figure_dir = osp.dirname(figure_path)
    if figure_dir:
        os.makedirs(figure_dir, exist_ok=True)

# Fall back to the CPU when CUDA is not available.
device = torch.device(args.device if torch.cuda.is_available() else "cpu")

loader = data_loader(args)

unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
# NOTE(review): torch.load unpickles the file; only load weights from a
# trusted source.
state_dict = torch.load(args.weights, map_location=device)
unet.load_state_dict(state_dict)
unet.eval()
unet.to(device)

# Per-slice inputs, predictions and ground truths, in loader order.
input_list = []
pred_list = []
true_list = []

with torch.no_grad():
    for x, y_true in loader:
        x, y_true = x.to(device), y_true.to(device)
        y_pred = unet(x)

        # Iterating a batch array yields one entry per sample along axis 0.
        pred_list.extend(y_pred.cpu().numpy())
        true_list.extend(y_true.cpu().numpy())
        input_list.extend(x.cpu().numpy())

volumes = postprocess_per_volume(
    input_list,
    pred_list,
    true_list,
    loader.dataset.patient_slice_index,
    loader.dataset.patients,
)

# DSC distribution
dsc_dist = dsc_distribution(volumes)

dsc_dist_plot = plot_dsc(dsc_dist)
imsave(args.dsc_figure, dsc_dist_plot)

# Accuracy distribution
acc_dist = acc_distribution(volumes)

acc_plot = plot_accuracy(acc_dist)
imsave(args.accuracy_figure, acc_plot)

# Per-patient summary of both metrics.
for p in volumes:
    acc = acc_dist[p]
    dsc_value = dsc_dist[p]
    print(f"Patient {p}: Accuracy = {acc:.4f}, DSC = {dsc_value:.4f}")

reading validation images...
Load dataset from cache: .cache\validation.pkl
done creating validation dataset
Patient kaggle_3m\TCGA_HT_7616_19940813: Accuracy = 0.9929, DSC = 0.7638
Patient kaggle_3m\TCGA_CS_6668_20011025: Accuracy = 0.9859, DSC = 0.0000
Patient kaggle_3m\TCGA_CS_4944_20010208: Accuracy = 0.9932, DSC = 0.8842
Patient kaggle_3m\TCGA_HT_7879_19981009: Accuracy = 0.9986, DSC = 0.9086
Patient kaggle_3m\TCGA_DU_7014_19860618: Accuracy = 0.9966, DSC = 0.9002
Patient kaggle_3m\TCGA_DU_6408_19860521: Accuracy = 0.9969, DSC = 0.9386
Patient kaggle_3m\TCGA_DU_6404_19850629: Accuracy = 0.9992, DSC = 0.9346
Patient kaggle_3m\TCGA_DU_5851_19950428: Accuracy = 0.9989, DSC = 0.9271
Patient kaggle_3m\TCGA_CS_6667_20011105: Accuracy = 0.9990, DSC = 0.9294
Patient kaggle_3m\TCGA_HT_7692_19960724: Accuracy = 0.9994, DSC = 0.9063


In [14]:
# Average the per-patient Dice scores over all validation volumes.
mean_dsc = sum(dsc_dist[patient] for patient in volumes) / len(volumes)
print(f"Mean DSC: {mean_dsc:.4f}")

Mean DSC: 0.8093


In [None]:
# Write one PNG per slice with the predicted and the ground-truth outlines
# drawn on top of the input image.
for patient, volume in volumes.items():
    inputs, predictions, targets = volume[0], volume[1], volume[2]
    for slice_idx in range(inputs.shape[0]):
        overlay = gray2rgb(inputs[slice_idx, 1])  # channel 1 is for FLAIR
        # Prediction outline first, ground truth second
        # (presumably red and green, assuming RGB colour order).
        overlay = outline(overlay, predictions[slice_idx, 0], color=[255, 0, 0])
        overlay = outline(overlay, targets[slice_idx, 0], color=[0, 255, 0])

        out_name = "{}-{}.png".format(patient, str(slice_idx).zfill(2))
        out_path = os.path.join(args.predictions, out_name)
        # Patient keys can contain a directory part (see the printed keys in
        # the output above), so create the file's parent folder on demand.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        imsave(out_path, overlay)