In [None]:
%matplotlib widget
from matplotlib import pyplot as plt

In [None]:
import os

os.chdir("../..")
import fnmatch
import random
import time
from glob import glob

import imageio as io
import numpy as np
from sklearn import metrics

from polygeist.label import process_files_and_folders
from polygeist.slidecore.slide import AperioSlide as Slide
from polygeist.slidecore.slide import SpectralSlideGenerator
from polygeist.training import train_model
from polygeist.utils import (
    SegmentationFilesDirectoryHandler,
    get_case_and_slide,
    load_filenames_and_generate_conditions,
)
from polygeist.validation import validate

# Overview

This workbook produces a classifier that detects alpha-synuclein (a-syn) pathology from colourimetric segmentations.  It either segments the SVS slide files, using the spectral decomposition technique described in WP1, or, if simulated data is enabled, generates random noise images instead.  The dumped images are then sorted into training and test sets and used to train PDNet.

## Configuration

The configuration here specifies where the SVS files are, a 'working root' where we will dump the images, and protein-specific settings.  If you are using synthetic data, the case_files will not be used, but we will still use the case identifiers for consistency.

In [None]:
config = {
    # Directory where the SVS slides are stored
    "svs_data_location": "/home/brad/localnas/",
    # Directory where all our segmentations, model files and sets will be stored
    "working_root": "/run/media/brad/ScratchM2/asyn_512_wp2_run/",
    # Case filenames, parsed so that whole cases are kept in either training or test
    "case_files": "Data/filenames/asyn_files.txt",
    # --- Segmentation-specific settings ---
    # The stride over which we will look (the window size), in pixels
    "stride": 512,
    # PUK filenames contain ID-INDEX_Protein, so this selects slide 17 (DMNoV) with the A-syn stain
    "index": "17_A-syn",
    # DAB activations below this threshold are considered noise
    "raw_threshold": -0.3,
    # Amount of pixels (as a percentage) per region that must be activated to define an ROI
    "class_threshold": 0.00125,
    # --- PDNet configuration ---
    "batch_size": 16,  # Adjust for memory constraints (may affect results)
    "num_epochs": 500,  # Adjust for time available for training (may affect results)
    # Set to True if you have generated simulated slides
    "simulated_data": False,
}

In [None]:
# Directory for this training run; model checkpoints are saved here
model_dump_dir = config["working_root"] + "/model_dump/"

In [None]:
# Load every case listed in the case file together with the condition(s) of each case
case_conditions = load_filenames_and_generate_conditions(
    config["case_files"]
)

In [None]:
# Uniformly split conditions
def split_cases_into_train_and_test(case_cond, condition):
    """Alternately assign the cases carrying ``condition`` to test and train.

    Cases are visited in the iteration order of ``case_cond``.  The 1st, 3rd, 5th, ...
    matching case goes to test, and the 2nd, 4th, ... goes to train.  The two lists
    therefore differ in length by at most one, with test getting the extra case.

    Args:
        case_cond: mapping of case identifier -> condition(s); a case matches when
            ``condition in case_cond[case]``.
        condition: the condition label to select, e.g. "PD" or "C".

    Returns:
        A ``(train, test)`` tuple of lists of case identifiers.
    """
    matching_cases = [case for case, conds in case_cond.items() if condition in conds]
    # Odd positions (1, 3, ...) -> train; even positions (0, 2, ...) -> test
    return matching_cases[1::2], matching_cases[0::2]

## Splitting Cases (Brains) into Training and Test

Here we split the cases (brains) into training and test groups.  Matching cases are assigned alternately to each group, so a whole case never appears in both.  If we are using simulated data, we then generate a set of random regions to pass to PDNet.

In [None]:
# Split PD cases first, then controls ("C"), each into (train, test) case lists
(pd_train, pd_test), (con_train, con_test) = (
    split_cases_into_train_and_test(case_conditions, condition)
    for condition in ("PD", "C")
)

In [None]:
# Simulated runs only: for every case, write a random number (0-9) of random
# stride x stride images into the working root.  Files are named
# "<case>-<index>.svs<n>.jpg", using the configured slide index rather than a
# hard-coded one (presumably mirroring the naming of real segmentation dumps).
if config["simulated_data"]:
    for subset in (pd_train, pd_test, con_train, con_test):
        for case in subset:
            for n in range(np.random.randint(10)):
                filename = f"{config['working_root']}/{case}-{config['index']}.svs{n}.jpg"
                print(f"Generating and writing random image to {filename}")
                SpectralSlideGenerator(
                    width=config["stride"], height=config["stride"], filename=filename
                )

In [None]:
# Belt and braces: the case-level split should make train and test disjoint, so any
# overlap reported here points to a problem upstream.
for overlap_train, overlap_test, overlap_label in (
    (pd_train, pd_test, "PD"),
    (con_train, con_test, "Control"),
):
    if set(overlap_train).intersection(overlap_test):
        print(f"There is an overlap between training and test images for {overlap_label}")

# Segmentation

The 'process_files_and_folders' routine is the main segmentation procedure.  It loads our data and spectrally decomposes it into DAB channels, then dumps each identified ROI into the working directory.  We skip this for simulated data, as all our regions are already random data.  Should you wish to run this, see the WP1 workbook.

This routine will report the status of each segmented slide.  It can take many hours to complete if every slide is being processed.

In [None]:
# Segment the real SVS slides and dump each ROI into the working root.  This is skipped
# for simulated data, whose random regions were already written above.
if not config["simulated_data"]:
    process_files_and_folders(
        # The input data folder, this is where the SVS files are located
        config["svs_data_location"],
        # Where we would like to dump the segmentations, and json files
        config["working_root"],
        # This is the stride over which we will look (the window size)
        stride=config["stride"],
        # This is the threshold under which a DAB activation will be considered noise
        raw_threshold=config["raw_threshold"],
        # This is the amount of pixels (as a percentage) per region that have to be activated to define a ROI
        class_threshold=config["class_threshold"],
        # Do not output full res density images
        output_density=False,
        # Output json metadata & density information
        output_json=True,
        # Skip outputting whole JPEGs
        skip_jpeg=True,
        # Automatically remove the slide background (note this is specialised to PUK Brain Slide Protocol)
        auto_remove_background=True,
        # Only include slides matching the configured index (config["index"])
        include_only_index=config["index"],
        # Output each ROI as a JPEG for CNN training (and obs)
        output_segmentation=True,
        # Please provide print feedback on processing
        verbose=True,
        # Synthetic-data flag.  The config key is "simulated_data"; the previous
        # config["simulated"] lookup raised KeyError.
        synthetic=config["simulated_data"],
    )

## Splits

The file handler will split the files into a 'train' and 'val' directory and further into group folders for the torch Dataset handlers.

In [None]:
# Handler that sorts the segmentations under the working root into 'train'/'val'
# group folders for the torch Dataset handlers.
files_handler = SegmentationFilesDirectoryHandler(
    config["working_root"]
)

In [None]:
# Create the 'train' and 'val' directory structure with its per-condition group folders.
files_handler.make_train_and_validation_folders_for_conditions()

In [None]:
# Copy the root segmentations into the train/val condition folders, one call per
# (case list, condition, split).  Only slide index 17 is used.  The calls are made
# in the same order as before: PD train, Controls train, PD test, Controls test.
# NOTE(review): the API keyword is `case_filter_for_train`, but here it is passed the
# case list for whichever split `training` selects, including the test lists.
for split_cases, split_condition, is_training in (
    (pd_train, "PD", True),
    (con_train, "Controls", True),
    (pd_test, "PD", False),
    (con_test, "Controls", False),
):
    files_handler.split_and_copy_root_data_to_train_and_validation(
        case_filter_for_train=split_cases,
        condition=split_condition,
        training=is_training,
        slide_index_filter=[17],
    )

In [None]:
# Root of the partitioned (train/val) image folders consumed by training and validation
training_dump_path = f"{config['working_root']}/partitioned_data"

## Balancing

This routine samples the raw slides to find more control regions, should the number of control samples be low.  This is only relevant for real data, so it is skipped for synthetic data.  The segmentation procedure is very good at not producing ROIs for the control group, which would cause a data imbalance during training.  By randomly sampling extra control regions we preserve the data balance.

In [None]:
# Balance the training set to prevent overfitting.  Segmentation rarely yields ROIs for
# controls, so we randomly sample new tiles from the raw control-training slides until
# there are at least as many control tiles as PD tiles.
# !! NOTE !! This only works when the PD set is larger than the control set.  If the
# control set were larger, we would need more pathology examples, which random sampling
# cannot provide.
l_control = len(glob(f"{training_dump_path}/train/Controls/*.jpg"))
l_test = len(glob(f"{training_dump_path}/train/PD/*.jpg"))

if not config["simulated_data"]:
    # Tile edge length in pixels, the same window size used for segmentation
    tile = config["stride"]
    # Number of tiles to keep clear of each slide edge, so samples fall well within tissue
    edge_margin = 4
    # Number of random tiles written per sampled slide
    tiles_per_slide = 25

    def get_valid_control_file_list(valid_cases):
        """Return paths of slide-17 SVS files under svs_data_location whose case is in `valid_cases`."""
        matches = []
        for root, _dirnames, filenames in os.walk(config["svs_data_location"]):
            for filename in fnmatch.filter(filenames, "*.svs"):
                case, slide = get_case_and_slide(filename)
                if case not in case_conditions:
                    continue
                if case in valid_cases and slide == 17:
                    matches.append(os.path.join(root, filename))
        return matches

    control_file_list = get_valid_control_file_list(con_train)
    control_injection_index = 0
    while l_test > l_control:
        # Without this guard, an empty list would raise a bare IndexError.  The list is
        # empty when no control slides were found or all were dropped as too small.
        if not control_file_list:
            raise RuntimeError(
                "Cannot balance the training set: no usable control slides "
                f"({l_control} control vs {l_test} PD tiles)"
            )
        control_path = random.choice(control_file_list)
        slide = Slide(control_path).get_slide_with_pixel_resolution_in_microns(2.0)
        filename = os.path.basename(control_path)

        yy, xx, _ = slide.shape
        # Number of (whole or partial) tiles along each axis
        x_pass = int(np.ceil(xx / tile))
        y_pass = int(np.ceil(yy / tile))

        # np.random.randint(low, high) requires low < high.  A slide this small has no
        # tiles at least `edge_margin` tiles from every edge, so drop it instead of crashing.
        if min(x_pass, y_pass) <= 2 * edge_margin:
            print(f"Skipping {filename}: too small to sample away from the slide edges")
            control_file_list.remove(control_path)
            continue

        # Make sure we are well within the slide tissue
        xs = np.random.randint(edge_margin, x_pass - edge_margin, tiles_per_slide)
        ys = np.random.randint(edge_margin, y_pass - edge_margin, tiles_per_slide)
        for x, y in zip(xs, ys):
            im = slide[y * tile : (y + 1) * tile, x * tile : (x + 1) * tile, :].copy()
            io.imwrite(
                f"{training_dump_path}/train/Controls/CI_{filename}_{control_injection_index}.jpg",
                im,
            )
            control_injection_index += 1

        # Recount
        l_control = len(glob(f"{training_dump_path}/train/Controls/*.jpg"))
        l_test = len(glob(f"{training_dump_path}/train/PD/*.jpg"))

## Training

We now pass all our parameters such as where the images are dumped, and the model directory where we will put our epoch checkpoint files, to the train_model routine.

In [None]:
# We don't inject into the validation set, that is kept clean for validation of the colourimetric segmentation.

# Train PDNet on the partitioned data.  Epoch checkpoints go to model_dump_dir, and the
# returned name is used by the validation cells as the latest model.
# perf_counter is monotonic, so the duration is unaffected by system clock adjustments.
start_time = time.perf_counter()

latest_model_name = train_model(
    training_dump_path,
    model_dump_dir,
    config["batch_size"],
    config["num_epochs"],
    strict=False,
)

time_elapsed = time.perf_counter() - start_time
# divmod + formatting avoids printing raw floats such as "8.0m 12.345678s"
elapsed_minutes, elapsed_seconds = divmod(time_elapsed, 60)
print(f"Training complete in {elapsed_minutes:.0f}m {elapsed_seconds:.0f}s")

## Validation

Here we load the last checkpoint file.  I have left its name hard-coded, so make sure you change it to the model file that you have generated (or disable the override).

Then we will pass our new model filename and the training path to our validation routine, which will run a pass on the 'val' folder.

In [None]:
# Checkpoint to validate.  A name here overrides the model that train_model just produced,
# e.g. when re-running validation in a later session.  Set it to None to validate the
# freshly trained `latest_model_name` instead.
model_name_override = "PDNET_checkpoint_70_11_05_12"
if model_name_override is not None:
    latest_model_name = model_name_override
# Now we can run validation, on slide and case level
# E.g. model_file = f"{model_dump_dir}/PDNET_checkpoint_490_16_18_48"
model_file = f"{model_dump_dir}/{latest_model_name}"

In [None]:
# Run the selected model over the 'val' folder of the partitioned data; the result holds
# the collected "outputs" and "labels", which are stacked in the next cell.
output_data_and_labels = validate(
    model_file,
    training_dump_path,
    config["batch_size"],
)

In [None]:
# Concatenate the collected model scores and their labels, then split the scores by label
outputs, labels = (
    np.hstack(output_data_and_labels[key]) for key in ("outputs", "labels")
)

matched = outputs[labels == 1.0]
non_matched = outputs[labels == 0]

In [None]:
# Drop samples labelled 2 (the 'EXCLUDE' folder); with no EXCLUDE folder this does nothing.
# The right-hand side is evaluated first, so both arrays use the mask from the unfiltered labels.
outputs, labels = outputs[labels < 2], labels[labels < 2]

In [None]:
# Full ROC curve (every threshold kept) for pathology (label 1) vs control (label 0)
fpr, tpr, thresholds = metrics.roc_curve(
    y_true=labels,
    y_score=outputs,
    drop_intermediate=False,
)

In [None]:
# ROC curve of the validation scores, with the AUC as a one-number summary and the
# chance diagonal for reference.  Uses the explicit Axes API so the figure is self-contained.
roc_fig, roc_ax = plt.subplots()
roc_auc = metrics.auc(fpr, tpr)
roc_ax.plot(fpr, tpr, label=f"Asyn Pathology vs Control (AUC = {roc_auc:.3f})")
roc_ax.plot([0, 1], [0, 1], linestyle="--", color="grey", label="Chance")
roc_ax.set_title("Validation ROC", fontsize=18)
roc_ax.set_xlabel("False Alarm Rate", fontsize=18)
roc_ax.set_ylabel("Hit Rate", fontsize=18)
roc_ax.tick_params(labelsize=18)
roc_ax.legend()
plt.show()

In [None]:
# Index of the ROC operating point to inspect.  This is a position in the roc_curve
# arrays, not a threshold value.
th = 70
threshold_at_th, tpr_at_th, fpr_at_th = thresholds[th], tpr[th], fpr[th]
print(f"Threshold = {threshold_at_th}, TP : {tpr_at_th}, FP {fpr_at_th}")

In [None]:
# Row-normalised confusion matrix at threshold t.  Rows are the true class (0 = control,
# 1 = pathology) and columns are the predicted class.  A score >= t counts as pathology.
t = thresholds[th]
control_scores = outputs[labels == 0]
path_scores = outputs[labels == 1]
N_0 = len(control_scores)
N_1 = len(path_scores)
conf = [
    (np.sum(control_scores < t) / N_0, np.sum(control_scores >= t) / N_0),
    (np.sum(path_scores < t) / N_1, np.sum(path_scores >= t) / N_1),
]

In [None]:
# Confusion matrix: rows are the true class, columns the predicted class
print(*(heading.ljust(10) for heading in ("", "Control", "Path")))
for row_name, (pred_control, pred_path) in zip(("Control", "Path"), conf):
    print(row_name.rjust(10), f"{pred_control:.4f}".ljust(10), f"{pred_path:.4f}".ljust(10))