# Out-of-Distribution Detection for Model Refinement in Cardiac Image Segmentation

In [None]:
# List the GPUs visible to this runtime.
!nvidia-smi -L

In [None]:
import random
import numpy as np
import torch

# Seed every random number generator used by the notebook (Python, NumPy,
# PyTorch CPU and CUDA) with the same fixed value.
for seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(0)

# Request deterministic cuDNN kernels and disable the cuDNN auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Install requirements

#### megatools

In [None]:
!sudo apt-get install megatools

#### batchgenerators

In [None]:
# Install batchgenerators, upgrading it if an older version is already present.
!pip install --upgrade batchgenerators

#### medpy

In [None]:
# Install the medpy package.
!pip install medpy

#### nibabel

In [None]:
# Install the nibabel package.
!pip install nibabel

## Import data

In [None]:
# Clone the MnMs2 repository into the current working directory.
!git clone https://github.com/manigalati/MnMs2

In [None]:
# Switch the working directory to the cloned repository.
# NOTE(review): assumes the clone above was made from /content — confirm when running elsewhere.
%cd /content/MnMs2

In [None]:
# Download the file shared at this mega.nz link into the current directory.
# NOTE(review): the contents of the download are not visible from this notebook — confirm what it provides.
!megadl "https://mega.nz/#!RERgEZhK!C9SXLaRV6Ep3pRn5ips_rGLsHUxP9wDZeJlBOSY6Af4"

In [None]:
!mkdir data
#MnMs2 dataset
!python gdrivedl.py https://drive.google.com/open?id=1-18mR3kwzErkqeTDZ-_UzQtalvnYK1fD data/
#vendor information
!python gdrivedl.py https://drive.google.com/open?id=1B8Rv0xGExxEahnYE7LMx0MeiJVXdaDPM data/

In [None]:
!unrar e data/MnM-2.rar data/
!rm data/MnM-2.rar

## Data preparation

In [None]:
import os
from utils import get_vendor_info
from utils import get_splits
from utils import generate_patient_info, crop_image
from utils import preprocess, preprocess_image, inSplit

The challenge cohort was composed of 360 patients with different right ventricle and left ventricle pathologies as well as healthy subjects. All subjects were scanned in four clinical centres in two different countries (Spain, Germany) using four different magnetic resonance scanner vendors (Siemens, General Electric and Philips).

The training set contained 200 annotated images from four different centres. The CMR images have been segmented by experienced clinicians from the respective institutions, including contours for the left (LV) and right ventricle (RV) blood pools, as well as for the left ventricular myocardium (MYO). Labels are: 1 (LV), 2 (MYO) and 3 (RV) in both short-axis and long-axis views with a variety of difficult RV pathologies and remodelling as well as LV pathologies. This year we will focus on RV segmentation. Labels 1 and 2 will be provided but will not score in the final challenge results. 40 cases, 5 for each pathology, will be used to create a public leaderboard and will be added to the final testing set. Two pathologies (Tricuspidal Regurgitation and Congenital Arrhythmogenesis) will not be present in the training set but will appear in the validation and testing sets to evaluate generalisation to unseen pathologies.

<table style="width:70%; margin: auto">
        <tbody><tr>
            <th>Pathology</th><th>Num. studies training</th> <th>Num. studies validation</th>
        </tr>
        <tr>
            <td style="text-align:left">Normal subjects</td><td>40</td><td>5</td>
        </tr>
        <tr>
            <td style="text-align:left">Dilated Left Ventricle</td><td>30</td><td>5</td>
        </tr>
        <tr>
            <td style="text-align:left">Hypertrophic Cardiomyopathy</td><td>30</td><td>5</td>
        </tr>
        <tr>
            <td style="text-align:left">Congenital Arrhythmogenesis</td><td>20</td><td>5</td>
        </tr>
        <tr>
            <td style="text-align:left">Tetralogy of Fallot</td><td>20</td><td>5</td>
        </tr>
        <tr>      
            <td style="text-align:left">Interatrial Communication</td><td>20</td><td>5</td>
        </tr>
        <tr>       
            <td style="text-align:left">Dilated Right Ventricle</td><td>0</td><td>5</td>
        </tr>
        <tr>
            <td style="text-align:left">Tricuspidal Regurgitation</td><td>0</td><td>5</td>
        </tr>
</tbody></table>


In [None]:
# Read the vendor information file of the training set and display it.
vendor_csv = "data/MNM2_vendor_information_training_set.csv"
vendor_info = get_vendor_info(vendor_csv)
vendor_info

In [None]:
# Count the subject codes recorded for each vendor and show the result
# under the column name NUM_STUDIES.
vendor_and_subject = vendor_info[['VENDOR', 'SUBJECT_CODE']]
studies_per_vendor = vendor_and_subject.groupby(['VENDOR']).count()
studies_per_vendor.rename(columns={'SUBJECT_CODE': 'NUM_STUDIES'})

In [None]:
# Make sure the output folder exists, then obtain the dataset splits
# associated with preprocessed/splits.pkl.
os.makedirs("preprocessed", exist_ok=True)
splits_path = os.path.join("preprocessed", "splits.pkl")
splits = get_splits(splits_path)

In [None]:
# Build the per-patient metadata from the vendor table; the pickle path is
# handed to generate_patient_info.
patient_info_path = os.path.join("preprocessed", "patient_info.pkl")
patient_info = generate_patient_info(vendor_info, patient_info_path)

In [None]:
# Target spacing: the 50th percentile, taken along axis 0, of the spacings of
# every training (labelled + unlabelled) and validation patient, with the
# "SA" and "LA" entries pooled together.
reference_ids = splits["train"]["lab"] + splits["train"]["ulab"] + splits["val"]
reference_keys = [
    "{:03d}_{}".format(patient_id, view)
    for view in ["SA", "LA"]
    for patient_id in reference_ids
]
spacings = [patient_info[key]["spacing"] for key in reference_keys]
spacing_target = np.percentile(np.vstack(spacings), 50, 0)

# Create every output folder up front.
for folder in (
    "preprocessed/training/labelled",
    "preprocessed/training/unlabelled",
    "preprocessed/validation/",
    "preprocessed/soft_validation/",
    "preprocessed/testing/",
):
    os.makedirs(folder, exist_ok=True)

# (patient ids, destination folder, extra keyword arguments) for each run.
# The validation split is processed twice: once with the default settings and
# once with soft_preprocessing=True; the test split only with the latter.
preprocessing_jobs = [
    (splits["train"]["lab"], "preprocessed/training/labelled/", {}),
    (splits["train"]["ulab"], "preprocessed/training/unlabelled/", {}),
    (splits["val"], "preprocessed/validation/", {}),
    (splits["val"], "preprocessed/soft_validation/", {"soft_preprocessing": True}),
    (splits["test"], "preprocessed/testing/", {"soft_preprocessing": True}),
]
for split_ids, destination, extra_kwargs in preprocessing_jobs:
    selected = {k: v for k, v in patient_info.items() if inSplit(k, split_ids)}
    preprocess(selected, spacing_target, destination, **extra_kwargs)

## $\mathcal{M}$ - Supervised Training

In [None]:
import torch.nn as nn
import os
import torch

from baseline_1 import Baseline_1
from baseline_2 import Baseline_2
from utils import AttrDict
from utils import GDLoss, CELoss
from utils import device
from utils import Validator, Checkpointer
from utils import supervised_training
from utils import ACDCDataLoader, ACDCAllPatients
from utils import BATCH_SIZE, EPOCHS
from utils import transform_augmentation_downsample, transform
from utils import plot_history

In [None]:
# Network class used for both views.
Model = Baseline_2

# One independent network per view ("SA" and "LA"), built from the same
# configuration (presumably learning rate and loss functions — see the Model class).
model = nn.ModuleDict([
    [axis, Model(
        AttrDict(**{
            "lr": 0.01,
            "functions": [GDLoss, CELoss]
        })
    )] for axis in ["SA", "LA"]
]).to(device)

# Optional resume: set ckpts to a {view: path} mapping whose file names are
# "<number>.pth". The number parsed from the last checkpoint iterated becomes `start`.
ckpts = None
if ckpts is not None:
    for axis, ckpt in ckpts.items():
        _, start = os.path.split(ckpt)
        start = int(start.replace(".pth", ""))
        # map_location lets a checkpoint saved on a GPU be restored on the
        # device used by this run (e.g. a CPU-only machine).
        ckpt = torch.load(ckpt, map_location=device)
        model[axis].load_state_dict(ckpt["M"])
        model[axis].optimizer.load_state_dict(ckpt["M_optim"])
else:
    start = 1
print(model)

In [None]:
# One Validator per view; the argument 5 is passed through unchanged.
validators = {view: Validator(5) for view in ("SA", "LA")}

for axis in ["SA", "LA"]:
    epochs = range(start, 5)  # EPOCHS

    # Labelled training data of this view, with augmentation.
    train_set = ACDCAllPatients(
        os.path.join("preprocessed/training/labelled/", axis),
        transform=transform_augmentation_downsample
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
    )

    # Validation data of this view, without augmentation.
    val_loader = ACDCDataLoader(
        os.path.join("preprocessed/validation/", axis),
        batch_size=BATCH_SIZE, transform=transform
    )

    checkpointer = Checkpointer(os.path.join("checkpoints/M", axis))

    supervised_training(
        model[axis], epochs, train_loader, val_loader, validators[axis], checkpointer
    )

    # Plot the "val" history collected by this view's validator.
    val_history = validators[axis].get_history("val")
    plot_history(val_history)

## $\mathcal{M}$ - Testing

In [None]:
import torch.nn as nn
import os
import torch
import pickle
import numpy as np

from baseline_1 import Baseline_1
from baseline_2 import Baseline_2
from utils import device
from utils import ACDCDataLoader
from utils import BATCH_SIZE
from utils import transform
from utils import infer_predictions
from utils import get_splits
from utils import postprocess_predictions, display_results

In [None]:
# Network class used for both views.
Model = Baseline_2

# One network per view, built with the class' default arguments.
model = nn.ModuleDict([
    [axis, Model()] for axis in ["SA", "LA"]
]).to(device)

for axis in ["SA", "LA"]:
    # Restore the weights stored under the "M" key of this view's checkpoint.
    # map_location lets a checkpoint saved on a GPU be restored on the device
    # used by this run (e.g. a CPU-only machine).
    ckpt = os.path.join("checkpoints/M", axis, "best_000.pth")
    state = torch.load(ckpt, map_location=device)
    model[axis].load_state_dict(state["M"])
    model[axis].to(device)
    model.eval()

    # Run the network on the preprocessed test data of this view; results are
    # handled by infer_predictions for the "inference/<view>" folder.
    infer_predictions(
        os.path.join("inference", axis),
        ACDCDataLoader(
            f"preprocessed/testing/{axis}",
            batch_size = BATCH_SIZE,
            transform = transform,
            transform_gt = False
        ),
        model[axis]
    )

In [None]:
# Use the splits file from the data-preparation step ("preprocessed/splits.pkl"),
# so the spacing computed below comes from the same patients as the target
# spacing used for preprocessing. The previous "checkpoints/splits.pkl" path
# could resolve to a different set of splits.
splits = get_splits(os.path.join("preprocessed", "splits.pkl"))

# NOTE: unpickling can execute arbitrary code; only load files produced by this notebook.
with open(os.path.join("preprocessed", "patient_info.pkl"),'rb') as f:
    patient_info = pickle.load(f)

# Spacings of all training (labelled + unlabelled) and validation patients,
# with the "SA" and "LA" entries pooled together.
spacings = [
    patient_info["{:03d}_{}".format(patient_id, axis)]["spacing"]
    for axis in ["SA", "LA"]
    for patient_id in (
        splits["train"]["lab"] + splits["train"]["ulab"] + splits["val"]
    )
]

# 50th percentile along axis 0, i.e. the same statistic as the preprocessing target spacing.
current_spacing = np.percentile(np.vstack(spacings), 50, 0)

In [None]:
# Post-process the predictions of each view (reading from "inference/<view>",
# with "postprocessed/<view>" as the output path) and show the collected results.
results = {
    view: postprocess_predictions(
        os.path.join("inference", view),
        patient_info,
        current_spacing,
        os.path.join("postprocessed", view),
    )
    for view in ("SA", "LA")
}

display_results(results)

## $\mathcal{R}$ - Training

In [None]:
import torch.nn as nn
import torch
import os

from reconstructor import Reconstructor
from utils import AttrDict
from utils import device
from utils import plot_history
from utils import ACDCAllPatients, ACDCDataLoader
from utils import transform
from utils import BATCH_SIZE

In [None]:
# One reconstructor per view ("SA" and "LA"), built from the same configuration.
ae = nn.ModuleDict([
    [axis, Reconstructor(
        AttrDict(**{
            "latent_size": 100,
            "lr": 2e-4,
            "last_layer": [4,2,1],
            "in_channels": 4,
            "weighted_epochs": 0
        })
    )] for axis in ["SA", "LA"]
]).to(device)

# Optional resume: set ckpts to a {view: path} mapping whose file names are
# "<number>.pth". The number parsed from the last checkpoint iterated becomes `start`.
ckpts = None
if ckpts is not None:
    for axis, ckpt in ckpts.items():
        _, start = os.path.split(ckpt)
        start = int(start.replace(".pth", ""))
        # map_location lets a checkpoint saved on a GPU be restored on the
        # device used by this run (e.g. a CPU-only machine).
        ckpt = torch.load(ckpt, map_location=device)
        ae[axis].load_state_dict(ckpt["R"])
        ae[axis].optimizer.load_state_dict(ckpt["R_optim"])
else:
    start = 0
print(ae)

In [None]:
for axis in ["SA", "LA"]:
    # Labelled training data of this view, without augmentation.
    train_loader = torch.utils.data.DataLoader(
        ACDCAllPatients(
            os.path.join("preprocessed/training/labelled/", axis),
            transform=transform
        ),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=0
    )
    val_loader = ACDCDataLoader(
        os.path.join("preprocessed/validation/", axis),
        batch_size=BATCH_SIZE, transform=transform
    )

    # Train this view's reconstructor and plot what the routine returns.
    history = ae[axis].training_routine(
        range(start, 5),  # 500
        train_loader,
        val_loader,
        os.path.join("checkpoints/R/", axis)
    )
    plot_history(history)

## QC-based Candidate Selection

In [None]:
import torch.nn as nn
import os
import torch
import pickle
import numpy as np

from baseline_1 import Baseline_1
from baseline_2 import Baseline_2
from reconstructor import Reconstructor
from utils import device
from utils import AttrDict
from utils import Validator
from utils import ACDCDataLoader
from utils import BATCH_SIZE
from utils import transform
from utils import GDLoss, CELoss, GDLoss_RV, CELoss_RV
from utils import infer_predictions
from utils import get_splits
from utils import postprocess_predictions
from utils import display_results

In [None]:
# Network class used for both views.
Model = Baseline_2

# One segmentation network per view; a fresh configuration object is created for each.
model = nn.ModuleDict([
    (
        view,
        Model(AttrDict(
            lr=0.01,
            functions=[GDLoss, CELoss],
            functions_RV=[GDLoss_RV, CELoss_RV],
        )),
    )
    for view in ["SA", "LA"]
]).to(device)

# One reconstructor per view, configured as in the reconstructor training section.
ae = nn.ModuleDict([
    (
        view,
        Reconstructor(AttrDict(
            latent_size=100,
            lr=2e-4,
            last_layer=[4, 2, 1],
            in_channels=4,
            weighted_epochs=0,
        )),
    )
    for view in ["SA", "LA"]
]).to(device)

# One Validator per view; the argument 5 is passed through unchanged.
validators = {view: Validator(5) for view in ("SA", "LA")}

for axis in ["SA", "LA"]:
    # Restore the reconstructor from the lexicographically last file whose
    # name contains "_best" in this view's checkpoint folder.
    ckpt_dir = os.path.join("checkpoints/R", axis)
    best_file = max(file for file in os.listdir(ckpt_dir) if "_best" in file)
    # map_location lets a checkpoint saved on a GPU be restored on the device
    # used by this run (e.g. a CPU-only machine).
    ckpt = torch.load(os.path.join(ckpt_dir, best_file), map_location=device)
    ae[axis].load_state_dict(ckpt["R"])
    ae.eval()

    # Evaluate every "best_*" segmentation checkpoint of this view on the test
    # data. os.listdir returns names in arbitrary order, so they are sorted to
    # make the evaluation order reproducible across runs and machines.
    for file in sorted(os.listdir(os.path.join("checkpoints/M", axis))):
        if "best_" not in file:
            continue
        ckpt = torch.load(os.path.join("checkpoints/M", axis, file), map_location=device)
        model[axis].load_state_dict(ckpt["M"])
        model.eval()
        with torch.no_grad():
            validators[axis].domain_evaluation(
                "test",
                model[axis],
                ACDCDataLoader(
                    f"preprocessed/testing/{axis}",
                    batch_size = BATCH_SIZE,
                    transform = transform,
                    transform_gt = False
                ),
                reconstructor=ae[axis]
            )


In [None]:
for axis in ["SA", "LA"]:
    # Preprocessed test data of this view.
    test_loader = ACDCDataLoader(
        f"preprocessed/testing/{axis}",
        batch_size = BATCH_SIZE,
        transform = transform,
        transform_gt = False
    )
    # Here the view's validator is supplied instead of a model.
    output_dir = os.path.join("inference", axis)
    infer_predictions(output_dir, test_loader, validator=validators[axis])

In [None]:
# Use the splits file from the data-preparation step ("preprocessed/splits.pkl"),
# so the spacing computed below comes from the same patients as the target
# spacing used for preprocessing. The previous "checkpoints/splits.pkl" path
# could resolve to a different set of splits.
splits = get_splits(os.path.join("preprocessed", "splits.pkl"))

# NOTE: unpickling can execute arbitrary code; only load files produced by this notebook.
with open(os.path.join("preprocessed", "patient_info.pkl"),'rb') as f:
    patient_info = pickle.load(f)

# Spacings of all training (labelled + unlabelled) and validation patients,
# with the "SA" and "LA" entries pooled together.
spacings = [
    patient_info["{:03d}_{}".format(patient_id, axis)]["spacing"]
    for axis in ["SA", "LA"]
    for patient_id in (
        splits["train"]["lab"] + splits["train"]["ulab"] + splits["val"]
    )
]

# 50th percentile along axis 0, i.e. the same statistic as the preprocessing target spacing.
current_spacing = np.percentile(np.vstack(spacings), 50, 0)

In [None]:
# Post-process the predictions of each view (reading from "inference/<view>",
# with "postprocessed/<view>" as the output path) and show the collected results.
results = {
    view: postprocess_predictions(
        os.path.join("inference", view),
        patient_info,
        current_spacing,
        os.path.join("postprocessed", view),
    )
    for view in ("SA", "LA")
}

display_results(results)

## Semi-Supervised Refinement

In [None]:
import torch.nn as nn
import os
import torch

from baseline_1 import Baseline_1
from baseline_2 import Baseline_2
from reconstructor import Reconstructor
from utils import AttrDict
from utils import GDLoss, CELoss, GDLoss_RV, CELoss_RV
from utils import device
from utils import Validator, Checkpointer
from utils import semisupervised_refinement
from utils import ACDCDataLoader, ACDCAllPatients
from utils import BATCH_SIZE, EPOCHS
from utils import transform_augmentation_downsample, transform
from utils import plot_history

In [None]:
#TODO: checkpoint should save also validator informations

# Network class used for both views.
Model = Baseline_2

# One segmentation network per view ("SA" and "LA"), built from the same configuration.
model = nn.ModuleDict([
    [axis, Model(
        AttrDict(**{
            "lr": 0.01,
            "functions": [GDLoss, CELoss],
            "functions_RV": [GDLoss_RV, CELoss_RV]
        })
    )] for axis in ["SA", "LA"]
]).to(device)

# One reconstructor per view, configured as in the reconstructor training section.
ae = nn.ModuleDict([
    [axis, Reconstructor(
        AttrDict(**{
            "latent_size": 100,
            "lr": 2e-4,
            "last_layer": [4,2,1],
            "in_channels": 4,
            "weighted_epochs": 0
        })
    )] for axis in ["SA", "LA"]
]).to(device)

# NOTE(review): with ckpts left as None, neither `model` nor `ae` receives any
# saved weights in this cell — see the TODO below.
ckpts = None#TODO: here ckpts should be compulsory, also with the autoencoder
if ckpts is not None:
    for axis, ckpt in ckpts.items():
        # File names are expected to be "<number>.pth"; the number parsed from
        # the last checkpoint iterated becomes `start`.
        _, start = os.path.split(ckpt)
        start = int(start.replace(".pth", ""))
        # map_location lets a checkpoint saved on a GPU be restored on the
        # device used by this run (e.g. a CPU-only machine).
        ckpt = torch.load(ckpt, map_location=device)
        model[axis].load_state_dict(ckpt["M"])
        model[axis].optimizer.load_state_dict(ckpt["M_optim"])
else:
    start = 1
print(model)

In [None]:
# One Validator per view; the argument 5 is passed through unchanged.
validators = {view: Validator(5) for view in ("SA", "LA")}

for axis in ["SA", "LA"]:
    epochs = range(start, 5)  # 200

    # Labelled training data of this view, with augmentation.
    labelled_loader = torch.utils.data.DataLoader(
        ACDCAllPatients(
            os.path.join("preprocessed/training/labelled/", axis),
            transform=transform_augmentation_downsample
        ),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=0
    )
    # Validation and unlabelled training data of this view, without augmentation.
    val_loader = ACDCDataLoader(
        os.path.join("preprocessed/validation/", axis),
        batch_size=BATCH_SIZE, transform=transform
    )
    unlabelled_loader = ACDCDataLoader(
        os.path.join("preprocessed/training/unlabelled/", axis),
        batch_size = BATCH_SIZE, transform=transform
    )
    checkpointer = Checkpointer(os.path.join("checkpoints/M_refinement", axis))

    semisupervised_refinement(
        model[axis],
        ae[axis],
        epochs,
        labelled_loader,
        val_loader,
        unlabelled_loader,
        validators[axis],
        checkpointer
    )

    # Plot the "val" history collected by this view's validator.
    val_history = validators[axis].get_history("val")
    plot_history(val_history)