## Inference on the validation set with the trained model

In [1]:
import os
import os.path as osp
import numpy as np
import torch
from dataclasses import dataclass

from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from medpy.filter.binary import largest_connected_component
from skimage.io import imsave
from torch.utils.data import DataLoader

from dataset import BrainSegmentationDataset as Dataset
from unet import UNet
from utils import dsc, gray2rgb, outline

In [2]:
def data_loader(args):
    dataset = Dataset(
        images_dir=args.images,
        subset="validation",
        image_size=args.image_size,
        random_sampling=False,
    )
    loader = DataLoader(
        dataset, batch_size=args.batch_size, drop_last=False, num_workers=1
    )
    return loader


def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    volumes = {}
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        volume_in = np.array(input_list[index : index + num_slices[p]])
        volume_pred = np.round(
            np.array(pred_list[index : index + num_slices[p]])
        ).astype(int)
        volume_pred = largest_connected_component(volume_pred)
        volume_true = np.array(true_list[index : index + num_slices[p]])
        volumes[patients[p]] = (volume_in, volume_pred, volume_true)
        index += num_slices[p]
    return volumes


def dsc_distribution(volumes):
    dsc_dict = {}
    for p in volumes:
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        dsc_dict[p] = dsc(y_pred, y_true, lcc=False)
    return dsc_dict


def plot_dsc(dsc_dist):
    y_positions = np.arange(len(dsc_dist))
    dsc_dist = sorted(dsc_dist.items(), key=lambda x: x[1])
    values = [x[1] for x in dsc_dist]
    labels = [x[0] for x in dsc_dist]
    labels = ["_".join(l.split("_")[1:-1]) for l in labels]
    fig = plt.figure(figsize=(12, 8))
    canvas = FigureCanvasAgg(fig)
    plt.barh(y_positions, values, align="center", color="skyblue")
    plt.yticks(y_positions, labels)
    plt.xticks(np.arange(0.0, 1.0, 0.1))
    plt.xlim([0.0, 1.0])
    plt.gca().axvline(np.mean(values), color="tomato", linewidth=2)
    plt.gca().axvline(np.median(values), color="forestgreen", linewidth=2)
    plt.xlabel("Dice coefficient", fontsize="x-large")
    plt.gca().xaxis.grid(color="silver", alpha=0.5, linestyle="--", linewidth=1)
    plt.tight_layout()
    canvas.draw()
    plt.close()
    s, (width, height) = canvas.print_to_buffer()
    return np.frombuffer(s, np.uint8).reshape((height, width, 4))


def makedirs(args):
    os.makedirs(args.predictions, exist_ok=True)

In [3]:
@dataclass
class Args:
    device = 'cuda:0'
    batch_size = 32
    weights = './weights/unet.pt'
    images = './BrainMRI/kaggle_3m'
    image_size = 256
    predictions = './predictions'
    figure = './dsc.png'

args = Args()

In [4]:
assert osp.exists(args.images), "Please download the dataset and set the correct path" 

makedirs(args)
device = torch.device("cpu" if not torch.cuda.is_available() else args.device)

loader = data_loader(args)

with torch.set_grad_enabled(False):
    unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
    state_dict = torch.load(args.weights, map_location=device)
    unet.load_state_dict(state_dict)
    unet.eval()
    unet.to(device)

    input_list = []
    pred_list = []
    true_list = []

    for i, data in enumerate(loader):
        x, y_true = data
        x, y_true = x.to(device), y_true.to(device)

        y_pred = unet(x)
        y_pred_np = y_pred.detach().cpu().numpy()
        pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])

        y_true_np = y_true.detach().cpu().numpy()
        true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])

        x_np = x.detach().cpu().numpy()
        input_list.extend([x_np[s] for s in range(x_np.shape[0])])

volumes = postprocess_per_volume(
    input_list,
    pred_list,
    true_list,
    loader.dataset.patient_slice_index,
    loader.dataset.patients,
)

dsc_dist = dsc_distribution(volumes)

dsc_dist_plot = plot_dsc(dsc_dist)
imsave(args.figure, dsc_dist_plot)

for p in volumes:
    x = volumes[p][0]
    y_pred = volumes[p][1]
    y_true = volumes[p][2]
    for s in range(x.shape[0]):
        image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
        image = outline(image, y_pred[s, 0], color=[255, 0, 0])
        image = outline(image, y_true[s, 0], color=[0, 255, 0])
        filename = "{}-{}.png".format(p, str(s).zfill(2))
        filepath = os.path.join(args.predictions, filename)
        dirpath = os.path.dirname(filepath)
        os.makedirs(dirpath, exist_ok=True)
        imsave(filepath, image)

reading validation images...
Load dataset from cache: .cache\validation.pkl
done creating validation dataset


RuntimeError: Error(s) in loading state_dict for UNet:
	Missing key(s) in state_dict: "downs.0.0.weight", "downs.0.1.weight", "downs.0.1.bias", "downs.0.1.running_mean", "downs.0.1.running_var", "downs.0.3.weight", "downs.0.4.weight", "downs.0.4.bias", "downs.0.4.running_mean", "downs.0.4.running_var", "downs.1.0.weight", "downs.1.1.weight", "downs.1.1.bias", "downs.1.1.running_mean", "downs.1.1.running_var", "downs.1.3.weight", "downs.1.4.weight", "downs.1.4.bias", "downs.1.4.running_mean", "downs.1.4.running_var", "downs.2.0.weight", "downs.2.1.weight", "downs.2.1.bias", "downs.2.1.running_mean", "downs.2.1.running_var", "downs.2.3.weight", "downs.2.4.weight", "downs.2.4.bias", "downs.2.4.running_mean", "downs.2.4.running_var", "downs.3.0.weight", "downs.3.1.weight", "downs.3.1.bias", "downs.3.1.running_mean", "downs.3.1.running_var", "downs.3.3.weight", "downs.3.4.weight", "downs.3.4.bias", "downs.3.4.running_mean", "downs.3.4.running_var", "ups.0.weight", "ups.0.bias", "ups.1.weight", "ups.1.bias", "ups.2.weight", "ups.2.bias", "ups.3.weight", "ups.3.bias", "up_convs.0.0.weight", "up_convs.0.1.weight", "up_convs.0.1.bias", "up_convs.0.1.running_mean", "up_convs.0.1.running_var", "up_convs.0.3.weight", "up_convs.0.4.weight", "up_convs.0.4.bias", "up_convs.0.4.running_mean", "up_convs.0.4.running_var", "up_convs.1.0.weight", "up_convs.1.1.weight", "up_convs.1.1.bias", "up_convs.1.1.running_mean", "up_convs.1.1.running_var", "up_convs.1.3.weight", "up_convs.1.4.weight", "up_convs.1.4.bias", "up_convs.1.4.running_mean", "up_convs.1.4.running_var", "up_convs.2.0.weight", "up_convs.2.1.weight", "up_convs.2.1.bias", "up_convs.2.1.running_mean", "up_convs.2.1.running_var", "up_convs.2.3.weight", "up_convs.2.4.weight", "up_convs.2.4.bias", "up_convs.2.4.running_mean", "up_convs.2.4.running_var", "up_convs.3.0.weight", "up_convs.3.1.weight", "up_convs.3.1.bias", "up_convs.3.1.running_mean", "up_convs.3.1.running_var", "up_convs.3.3.weight", "up_convs.3.4.weight", "up_convs.3.4.bias", "up_convs.3.4.running_mean", "up_convs.3.4.running_var", "final.weight", "final.bias". 
	Unexpected key(s) in state_dict: "encoder.0.0.weight", "encoder.0.1.weight", "encoder.0.1.bias", "encoder.0.1.running_mean", "encoder.0.1.running_var", "encoder.0.1.num_batches_tracked", "encoder.0.3.weight", "encoder.0.4.weight", "encoder.0.4.bias", "encoder.0.4.running_mean", "encoder.0.4.running_var", "encoder.0.4.num_batches_tracked", "encoder.1.0.weight", "encoder.1.1.weight", "encoder.1.1.bias", "encoder.1.1.running_mean", "encoder.1.1.running_var", "encoder.1.1.num_batches_tracked", "encoder.1.3.weight", "encoder.1.4.weight", "encoder.1.4.bias", "encoder.1.4.running_mean", "encoder.1.4.running_var", "encoder.1.4.num_batches_tracked", "encoder.2.0.weight", "encoder.2.1.weight", "encoder.2.1.bias", "encoder.2.1.running_mean", "encoder.2.1.running_var", "encoder.2.1.num_batches_tracked", "encoder.2.3.weight", "encoder.2.4.weight", "encoder.2.4.bias", "encoder.2.4.running_mean", "encoder.2.4.running_var", "encoder.2.4.num_batches_tracked", "encoder.3.0.weight", "encoder.3.1.weight", "encoder.3.1.bias", "encoder.3.1.running_mean", "encoder.3.1.running_var", "encoder.3.1.num_batches_tracked", "encoder.3.3.weight", "encoder.3.4.weight", "encoder.3.4.bias", "encoder.3.4.running_mean", "encoder.3.4.running_var", "encoder.3.4.num_batches_tracked", "decoder.0.weight", "decoder.0.bias", "decoder.1.0.weight", "decoder.1.1.weight", "decoder.1.1.bias", "decoder.1.1.running_mean", "decoder.1.1.running_var", "decoder.1.1.num_batches_tracked", "decoder.1.3.weight", "decoder.1.4.weight", "decoder.1.4.bias", "decoder.1.4.running_mean", "decoder.1.4.running_var", "decoder.1.4.num_batches_tracked", "decoder.2.weight", "decoder.2.bias", "decoder.3.0.weight", "decoder.3.1.weight", "decoder.3.1.bias", "decoder.3.1.running_mean", "decoder.3.1.running_var", "decoder.3.1.num_batches_tracked", "decoder.3.3.weight", "decoder.3.4.weight", "decoder.3.4.bias", "decoder.3.4.running_mean", "decoder.3.4.running_var", "decoder.3.4.num_batches_tracked", "decoder.4.weight", "decoder.4.bias", "decoder.5.0.weight", "decoder.5.1.weight", "decoder.5.1.bias", "decoder.5.1.running_mean", "decoder.5.1.running_var", "decoder.5.1.num_batches_tracked", "decoder.5.3.weight", "decoder.5.4.weight", "decoder.5.4.bias", "decoder.5.4.running_mean", "decoder.5.4.running_var", "decoder.5.4.num_batches_tracked", "decoder.6.weight", "decoder.6.bias", "decoder.7.0.weight", "decoder.7.1.weight", "decoder.7.1.bias", "decoder.7.1.running_mean", "decoder.7.1.running_var", "decoder.7.1.num_batches_tracked", "decoder.7.3.weight", "decoder.7.4.weight", "decoder.7.4.bias", "decoder.7.4.running_mean", "decoder.7.4.running_var", "decoder.7.4.num_batches_tracked", "final_conv.weight", "final_conv.bias". 