In [None]:
%matplotlib widget
from matplotlib import pyplot as plt

In [None]:
import os

# Work from two directories above the notebook (presumably the repository root, so that the relative paths used
# below resolve).  The starting directory is remembered the first time this cell runs, so re-running the cell
# returns to the same place instead of climbing two further levels on every execution.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, "..", ".."))
import time

import numpy as np
from sklearn import metrics

from polygeist.label import process_files_and_folders
from polygeist.training import train_model
from polygeist.utils import (
    SegmentationFilesDirectoryHandler,
    load_filenames_and_generate_conditions,
)
from polygeist.validation import validate

# Introduction

This workbook will process the Tau SVS slide files, producing regions of interest (ROIs) as jpegs for classification using the PDNet CNN.  Because Tau pathology also exists in controls, we will principally be discriminating on morphology; that is, on the shape, colour and size of Tau in the segmented regions of control and pathology cases.

We will:

- Search through the data directory with `polygeist.label`, which will process our SVS files and produce ROIs for classification.
- Segment those ROIs into training and test sets for use with PDNet
- Run the PDNet training routine on those training images
- Validate the model by loading it and running it on the validation dataset
- Use `sklearn.metrics` to evaluate our model, and plot an ROC curve.

In [None]:
# This configuration defines what we need to know about the slides and where to put the outputs; we also place the
# protein-specific information for Tau here.
#
# The two machine-specific locations can be overridden without editing the notebook by setting the environment
# variables POLYGEIST_SVS_DATA_LOCATION and POLYGEIST_WORKING_ROOT; when they are unset the values below are used.
config = {
    # This is where the SVS slides are stored
    "svs_data_location": os.environ.get(
        "POLYGEIST_SVS_DATA_LOCATION", "/home/brad/localnas/"
    ),
    # This is the directory where all our segmentations, model files and sets will be stored
    "working_root": os.environ.get(
        "POLYGEIST_WORKING_ROOT", "/run/media/brad/ScratchM2/Tau_label_dump_256/"
    ),
    # These are our case filenames, which we shall parse to ensure case level segmenting in training and test
    # (a relative path, so it is resolved against the current working directory)
    "case_files": "Data/filenames/tau_files.txt",
    # Segmentation Specific Information
    # This is the stride over which we will look (the window size)
    "stride": 256,
    # The PUK set contains ID-INDEX_Protein in the filename, so here we specify 10_Tau (only use slide 10)
    "index": "10_Tau",
    # This is the threshold under which a DAB activation will be considered noise
    "raw_threshold": -0.1,
    # This is the amount of pixels (as a percentage) per region that have to be activated to define a ROI
    "class_threshold": 0.025,
    # PDNET Configuration
    "batch_size": 32,  # Adjust for memory constraints (may affect results)
    "num_epochs": 500,  # Adjust for time available for training (may affect results)
}

# Spectral Estimation & Segmentation

This process utility (from `polygeist.label`) takes a list of protein specific parameters and performs the spectral estimation technique to produce estimates of the DAB staining.  It then segments ROIs with specific parameters and dumps them to disk as jpegs for later use.

In [None]:
# Keyword options for the spectral estimation / ROI segmentation pass
segmentation_options = {
    # Window size: the stride over which we will look
    "stride": config["stride"],
    # DAB activations under this threshold are considered noise
    "raw_threshold": config["raw_threshold"],
    # Amount of pixels (as a percentage) per region that must be activated to define a ROI
    "class_threshold": config["class_threshold"],
    # No full res density images
    "output_density": False,
    # Write json metadata & density information
    "output_json": True,
    # Whole-slide JPEGs are skipped
    "skip_jpeg": True,
    # Automatic slide background removal (note this is specialised to PUK Brain Slide Protocol)
    "auto_remove_background": True,
    # Only slides with 10_Tau in their name are included: slide index 10, Tau labelling
    "include_only_index": config["index"],
    # Each ROI is written as a JPEG for CNN training (and obs)
    "output_segmentation": True,
    # Print feedback while processing
    "verbose": True,
}

process_files_and_folders(
    # Input data folder, where the SVS files are located
    config["svs_data_location"],
    # Destination for the segmentations and json files
    config["working_root"],
    **segmentation_options,
)

In [None]:
# Get all the cases and our conditions for each.
# The result is consumed below as a mapping (iterated with `.items()`) from case to its condition(s).
case_conditions = load_filenames_and_generate_conditions(config["case_files"])

In [None]:
# Uniformly split conditions
def split_cases_into_train_and_test(case_cond, condition):
    """Alternate the cases that carry `condition` between the test and train lists.

    Cases are visited in the iteration order of `case_cond`; a case is kept when
    `condition in value` holds for its entry.  The first kept case goes to the test
    list, the second to the train list, and so on, so with an odd number of kept
    cases the test list is one longer.

    Parameters
    ----------
    case_cond : mapping of case -> condition container
    condition : value looked up (with `in`) in each case's condition container

    Returns
    -------
    (train, test) : tuple of two lists of case keys
    """
    selected = [case for case, conds in case_cond.items() if condition in conds]
    # Even positions (0, 2, ...) are test cases, odd positions (1, 3, ...) are train cases
    return selected[1::2], selected[0::2]

In [None]:
# Split the cases of each condition into a train list and a test list.
# "PD" and "C" are the condition labels looked up in `case_conditions`
# (presumably pathology and control cases — verify against the case file).
pd_train, pd_test = split_cases_into_train_and_test(case_conditions, "PD")
con_train, con_test = split_cases_into_train_and_test(case_conditions, "C")

In [None]:
# Directory handler rooted at the working directory, where the segmentations were written
files_handler = SegmentationFilesDirectoryHandler(config["working_root"])

In [None]:
# Prepare the train / validation folders for each condition before any data is copied
# (behaviour as suggested by the method name; see polygeist.utils for the exact layout).
files_handler.make_train_and_validation_folders_for_conditions()

# Sorting Data into Training and Test

Here we sort all regions into either the training or the test set; we balance by the number of images (N) in the PD condition.

In [None]:
# Distribute the segmented regions of each case list, one call per (condition, partition) pair.
# NOTE(review): the keyword is named `case_filter_for_train` even in the `training=False` calls, where it
# receives the *test* case lists — confirm against SegmentationFilesDirectoryHandler that this is intended.

# PD cases -> training partition
files_handler.split_and_copy_root_data_to_train_and_validation(
    case_filter_for_train=pd_train, condition="PD", training=True
)
# Control cases -> training partition
files_handler.split_and_copy_root_data_to_train_and_validation(
    case_filter_for_train=con_train, condition="Controls", training=True
)
# PD cases -> non-training partition
files_handler.split_and_copy_root_data_to_train_and_validation(
    case_filter_for_train=pd_test, condition="PD", training=False
)
# Control cases -> non-training partition
files_handler.split_and_copy_root_data_to_train_and_validation(
    case_filter_for_train=con_test, condition="Controls", training=False
)

In [None]:
# Our dump path for our model training run, model checkpoints will be saved here.
# It lives under the working root and is also where the checkpoint is read back from for validation.
model_dump_dir = f"{config['working_root']}/model_dump/"

In [None]:
# We will use a clean copy of the data for performance, repeatability and safety.
# NOTE(review): this same directory is passed to both train_model and validate below, although it is named
# "test_partitioned_data" — presumably it holds both the train and the validation partitions; confirm.
training_dump_path = files_handler.root + "/test_partitioned_data/"

# Model Training

The data layout is passed to the `train_model` utility, which produces a PDNet model for us.

In [None]:
# We don't inject into the validation set, that is kept clean for validation of the colourimetric segmentation.

# Start a timer
start_time = time.time()

# Train on the partitioned data; checkpoints are written to model_dump_dir.
# The return value is kept as the name of the latest model for the validation step.
latest_model_name = train_model(
    training_dump_path,
    model_dump_dir,
    config["batch_size"],
    config["num_epochs"],
    strict=False,
)

# Report the wall-clock duration as whole minutes and seconds; printing the raw
# floats would give output such as "12.0m 34.56789012s".
time_elapsed = time.time() - start_time
minutes, seconds = divmod(round(time_elapsed), 60)
print(f"Training complete in {minutes}m {seconds}s")

# Model Validation

The data layout is passed to the validation utility, which produces validation scores for us.  Here we load the last checkpoint file. The name is hard-coded, so make sure you change it to the model file that you have generated.

In [None]:
# Checkpoint used for validation, on slide and case level.
# NOTE: this hard-coded name replaces the value returned by train_model above; change it to a
# checkpoint that exists in your own model_dump_dir (or remove the line to use the trained model).
# E.g. model_file = f"{model_dump_dir}/PDNET_checkpoint_490_16_18_48"
latest_model_name = "PDNET_checkpoint_490_03_06_05"
model_file = model_dump_dir + "/" + latest_model_name

In [None]:
# Run the selected checkpoint over the partitioned data.
# The result is indexed below by "outputs" (model scores) and "labels" (ground truth).
output_data_and_labels = validate(model_file, training_dump_path, config["batch_size"])

In [None]:
# Stack the collected outputs and labels into one array each
# (assumes each entry is a sequence of per-batch arrays — TODO confirm against `validate`)
outputs = np.hstack(output_data_and_labels["outputs"])
labels = np.hstack(output_data_and_labels["labels"])

# Group the model outputs by their ground-truth label
is_label_one = labels == 1.0
is_label_zero = labels == 0
matched = outputs[is_label_one]
non_matched = outputs[is_label_zero]

In [None]:
# ROC analysis of the model outputs against the ground-truth labels: roc_curve returns the false-positive
# rates, the true-positive rates and the decision thresholds at which they were evaluated.
fpr, tpr, thresholds = metrics.roc_curve(labels, outputs)

In [None]:
# ROC curve: hit rate (true-positive rate) against false alarm rate (false-positive rate)
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label="Taupathology vs Control")
ax.legend()
ax.set_xlabel("False Alarm Rate", fontsize=18)
ax.set_ylabel("Hit Rate", fontsize=18)
# Enlarge the tick labels of the current axes
plt.yticks(fontsize=18)
plt.xticks(fontsize=18)
# plt.title("512um Patch Level Discrimination between Taupathology and Control Tau Segmentation")
plt.show()

In [None]:
# Index into the ROC arrays selecting the operating point to report
th = 60
chosen_threshold, hit_rate, false_alarm_rate = thresholds[th], tpr[th], fpr[th]
print(f"Threshold = {chosen_threshold}, TP : {hit_rate}, FP {false_alarm_rate}")

In [None]:
# Row-normalised confusion matrix at the chosen threshold.
# Rows: true label 0, then true label 1.  Columns: output below the threshold, then at or above it.
t = thresholds[th]
label_0_outputs = outputs[labels == 0]
label_1_outputs = outputs[labels == 1]
N_0 = len(label_0_outputs)
N_1 = len(label_1_outputs)
conf = [
    (np.sum(scores < t) / count, np.sum(scores >= t) / count)
    for scores, count in ((label_0_outputs, N_0), (label_1_outputs, N_1))
]

In [None]:
# Confusion matrix as a small text table: one header line, then one line per true class
column_titles = ("", "Control", "Path")
print(*(title.ljust(10) for title in column_titles))
for row_title, row in zip(("Control", "Path"), conf):
    print(row_title.rjust(10), f"{row[0]:.4f}".ljust(10), f"{row[1]:.4f}".ljust(10))