In [None]:
import os
os.chdir('../..')
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from glob import glob
from polygeist.utils import load_filenames_and_generate_conditions
import json
from tqdm import tqdm
from skimage.transform import resize
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import random
from datetime import datetime
from sklearn.metrics import confusion_matrix

# Introduction

This workbook uses the density maps produced in WP2, and the staging information provided by ICL, to produce a classifier.  This classifier can perform binary and multi-class classification of pathology based on those density maps.

The steps here are to:

- Define a data handler that can abstract the density maps and classifications into tensors and labels for training and validation
- Define a small CNN that will be used during training, and any data transforms to be applied
- Perform training and validation

In [None]:
class BraakDataset(Dataset):
    """In-memory dataset of per-case density-map stacks labelled with a Braak stage.

    For every requested case that appears in the staging spreadsheet, one density
    map per expected slide is read from JSON, resized to ``slide_size x slide_size``
    and stacked channel-first.  A slide for which no file is found is left as an
    all-zero channel.  Cases missing from the spreadsheet are skipped entirely.

    Args:
        cases: Iterable of case identifiers; each is used as a filename prefix
            under ``root`` and as an index key into the spreadsheet.
        root: Directory searched for ``{case}*.json`` density-map files.
        spreadsheet: Excel file read with its first column as the index; it must
            contain a ``STAGE`` column.
        transform: Optional callable applied to each sample in ``__getitem__``.
        slide_size: Edge length every density map is resized to.
        device: Device that the data and target tensors are moved to.
        reduce_classes: When true, collapse the raw stage as per ``reduce_type``.
        reduce_type: If it contains ``'binary'`` the target is ``stage > 0``;
            otherwise three classes are used: 0 for ``stage < 1``, 2 for
            ``stage == 6`` and 1 for everything else.

    Attributes are ``data`` (float32 tensor, one ``(slides, slide_size, slide_size)``
    entry per loaded case), ``targets`` (float32 tensor), ``device`` and ``transform``.
    """

    def __init__(self, cases, root, spreadsheet='/home/brad/repos/ParkinsonsSyntheticStaining/Data/'
                                                'cases_for_staging_asyn_c319_1_sg.xlsx',
                 transform=None, slide_size=32, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                 reduce_classes=False, reduce_type='binary'):
        # Load Braak Stages
        braak_staging = pd.read_excel(spreadsheet, index_col=0)

        # One channel per slide label, in this order.
        slide_labels = ['01_A', '02_A', '04_A', '15_A', '16_A', '17_A']

        self.device = device
        self.transform = transform

        maps = []
        stages = []

        for case in tqdm(cases, desc="Loading Cases"):
            # Skip unknown cases before doing any allocation or file I/O.
            if case not in braak_staging.index:
                continue

            stage = braak_staging.loc[case, 'STAGE']
            if reduce_classes:
                if 'binary' in reduce_type:
                    stage = stage > 0
                else:
                    stage = 0 if stage < 1 else (2 if stage == 6 else 1)

            # glob returns paths in arbitrary order; sorting makes the choice of
            # file deterministic when more than one path matches a slide label.
            files = sorted(glob(f"{root}/{case}*.json"))

            # Allocated channel-first, so no axis shuffle is needed afterwards.
            density_map_array = np.zeros((len(slide_labels), slide_size, slide_size))
            for i, label in enumerate(slide_labels):
                find_file = next((path for path in files if label in path), None)
                if find_file is None:
                    # Missing slide: its channel stays all zeros.
                    continue
                with open(find_file, 'r') as fp:
                    density_map = np.array(json.load(fp)['densities'])
                density_map_array[i] = resize(density_map, (slide_size, slide_size))

            maps.append(density_map_array)
            stages.append(stage)

        # Stack once in numpy: building a tensor from a Python list of ndarrays
        # goes through a slow element-wise conversion path.
        if maps:
            data = torch.as_tensor(np.stack(maps), dtype=torch.float32)
        else:
            # Same shape as torch.Tensor([]) so an empty dataset still has len 0.
            data = torch.empty(0, dtype=torch.float32)

        self.data = data.to(self.device)
        self.targets = torch.as_tensor(np.asarray(stages, dtype=np.float32)).to(self.device)

    def __getitem__(self, index):
        """Return the ``(sample, target)`` pair at ``index``, applying ``transform`` to the sample if set."""
        x = self.data[index]
        y = self.targets[index]

        if self.transform:
            x = self.transform(x)

        return x, y

    def __len__(self):
        """Return the number of loaded cases."""
        return len(self.data)

In [None]:
# The simple CNN to be used during training
class SeqBlock(nn.Module):
    """Three stacked convolutions followed by a ReLU and a 2x2 max-pool.

    The convolutions use kernel sizes 4, 3 and 2 with no padding; the first two
    keep ``cin`` channels and the last one maps ``cin`` to ``cout`` channels.
    No activation is applied between the convolutions.
    """

    def __init__(self, cin, cout):
        super().__init__()
        conv_stack = [
            nn.Conv2d(cin, cin, kernel_size=4),
            nn.Conv2d(cin, cin, kernel_size=3),
            nn.Conv2d(cin, cout, kernel_size=2),
        ]
        self.c1 = nn.Sequential(*conv_stack)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply the convolution stack, then ReLU, then max-pooling."""
        features = self.c1(x)
        activated = F.relu(features)
        return self.pool(activated)


class CNN(nn.Module):
    """Small convolutional classifier over a 6-channel density-map stack.

    Three ``SeqBlock`` stages (6->6, 6->5, 5->2 channels) with dropout after the
    first, followed by a linear layer on the flattened features and a sigmoid.
    The linear layer is fixed at 8 input features, so the input spatial size must
    be one for which the third block yields 2 channels of 2x2 (e.g. 64x64).

    Args:
        binary: When true the network has a single output unit, otherwise three.
    """

    def __init__(self, binary=False):
        super().__init__()
        n_outputs = 1 if binary else 3

        self.conv1 = SeqBlock(6, 6)
        self.drop1 = nn.Dropout(p=0.15)
        self.conv2 = SeqBlock(6, 5)
        self.conv3 = SeqBlock(5, 2)
        self.linear = nn.Linear(8, n_outputs)

    def forward(self, x):
        """Return per-output sigmoid activations for a batch ``x``."""
        for stage in (self.conv1, self.drop1, self.conv2, self.conv3):
            x = stage(x)
        logits = self.linear(x.flatten(1))
        return torch.sigmoid(logits)

# Loading and Processing the Dataset

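Here we get the dataset and shuffle the cases within each group (PD and control), then we take the first `prop` fraction (sized from the control group) of each group for training and hold out the remainder for testing.  We create our dataset objects for use in training and validation.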

In [None]:
# Build the case -> condition mapping from the list of a-syn filenames.  The
# result is consumed below via `.items()`, with a 'PD' membership test on each value.
case_conditions = load_filenames_and_generate_conditions("Data/filenames/asyn_files.txt")

In [None]:
# Fraction of the control group used to size the training split.
prop = .50
# Fixed seed so the shuffle, and therefore the train/test split, is identical on
# every run.  Without it a checkpoint saved by one run cannot be meaningfully
# evaluated in another, because the held-out cases would differ.
split_seed = 0

# Partition the cases into PD and control by their condition string.
pd_set = [case for case, condition in case_conditions.items() if 'PD' in condition]
con_set = [case for case, condition in case_conditions.items() if 'PD' not in condition]

random.seed(split_seed)
random.shuffle(pd_set)
random.shuffle(con_set)

# Number of cases taken from EACH group for training (sized from the controls).
sample = int(len(con_set) * prop)

In [None]:
# The first `sample` shuffled cases of each group train; everything after tests.
# The test slices start at `sample` (not `sample + 1`) so that the case at index
# `sample` is not silently dropped from both splits.
pd_train = pd_set[:sample]
pd_test = pd_set[sample:]
con_train = con_set[:sample]
con_test = con_set[sample:]

In [None]:
# TODO: Should these parameters be in a dict at the start of the notebook as in the previous ones?
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
slide_size = 64
reduce_classes = True
reduce_type = 'binary'
batch_size = 25
# Directory holding the per-case density-map JSON files (shared by both splits).
density_map_root = "/run/media/brad/ScratchM2/asyn_rerun_128"

train_dataset = BraakDataset(pd_train + con_train, density_map_root, transform=None,
                             slide_size=slide_size, device=dev,
                             reduce_classes=reduce_classes, reduce_type=reduce_type)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = BraakDataset(pd_test + con_test, density_map_root, transform=None,
                            slide_size=slide_size, device=dev,
                            reduce_classes=reduce_classes, reduce_type=reduce_type)
# No shuffling for evaluation: metrics averaged per batch then do not depend on
# a random batch composition.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Training and Validation

Here we run training with a high epoch count (this is due to a small network and tiny dataset).  Then we load the state dictionary saved during the training step, and evaluate our test dataset.

In [None]:
model = CNN().to(dev)
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()

# Per-batch running values (cumulative within each epoch), stored as plain floats.
t_loss_history = []
t_correct_history = []
v_correct_history = []
v_loss_history = []

num_epochs = 10000
best_val = 0


def _expected_class(outputs):
    """Collapse per-class activations to one score: sum of activation * class index."""
    class_index = torch.arange(0, outputs.shape[1], device=outputs.device)
    return torch.sum(outputs * class_index, 1)


def _run_epoch(dataloader, loss_history, correct_history, train):
    """Run one pass over ``dataloader`` using the module-level model/optimizer/criterion.

    When ``train`` is true the model is put in training mode and the optimizer is
    stepped per batch; otherwise the model is put in eval mode and gradients are
    disabled.  Appends the running loss / running accuracy sum after each batch
    to the given history lists.

    Returns:
        (summed batch loss, mean of the per-batch accuracies).
    """
    # Set the mode on every pass: otherwise the eval() of the previous
    # validation pass would leave dropout disabled for all later training.
    model.train(train)
    sum_loss = 0.0
    correct = 0.0
    run_count = 0
    with torch.set_grad_enabled(train):
        for data, labels in dataloader:
            if train:
                optimizer.zero_grad()
            classification = _expected_class(model(data))
            loss = criterion(classification, labels)
            if train:
                loss.backward()
                optimizer.step()
            sum_loss += loss.item()
            # .item() keeps plain floats in the histories instead of one tensor
            # that is mutated in place (which made all entries alias each other).
            correct += (torch.sum(torch.round(classification) == labels) / len(labels)).item()
            loss_history.append(sum_loss)
            correct_history.append(correct)
            run_count += 1
    accuracy = correct / run_count if run_count else 0.0
    return sum_loss, accuracy


for epoch in tqdm(range(num_epochs)):
    sum_loss, accuracy = _run_epoch(train_dataloader, t_loss_history, t_correct_history, train=True)
    if epoch % 250 == 0:
        print(f"Train -> Epoch {epoch} / {num_epochs} | Loss : {sum_loss}, Correct : {accuracy}")

    sum_loss, accuracy = _run_epoch(test_dataloader, v_loss_history, v_correct_history, train=False)
    if epoch % 250 == 0:
        print(f"Validation -> Epoch {epoch} / {num_epochs} | Loss : {sum_loss}, Correct : {accuracy}")

    # Checkpoint on a new best validation accuracy, ignoring the first few
    # epochs (the guard must test the current epoch, not the total epoch count).
    if accuracy > best_val and epoch > 3:
        torch.save(model.state_dict(), f"mmclass_{datetime.now().strftime('%d_%m_%Y__%H_%M_%S')}_.pth")
        best_val = accuracy


In [None]:
# TODO: Would it make more sense to have a branch in the notebook for "train xor load" based on a config option?
# map_location remaps the stored tensors onto `dev`, so a checkpoint written on
# a CUDA machine can also be loaded on a CPU-only one.
model.load_state_dict(torch.load('mmclass_23_02_2023__16_09_21_.pth', map_location=dev))

`Train -> Epoch 6900 / 10000 | Loss : 0.04006555293199199, Correct : 0.9942857623100281`
`Validation -> Epoch 6900 / 10000 | Loss : 0.9827206871559611, Correct : 0.8920000195503235`

In [None]:
# Collect rounded predictions and true labels, one list per batch.
running_results = []
running_labels = []

# Inference must run in eval mode (dropout off) regardless of what ran before
# this cell, and needs no autograd graph.
model.eval()
with torch.no_grad():
    for data, labels in test_dataloader:
        outputs = model(data)
        # Score = sum over outputs of activation * class index.
        class_index = torch.arange(0, outputs.shape[1]).to(dev)
        classification = torch.sum(outputs * class_index, 1)
        running_results.append(torch.round(classification).tolist())
        running_labels.append(labels.tolist())

In [None]:
# Flatten the per-batch lists into one 1-D array each, keeping labels and
# predictions aligned element-for-element.
conf_labels = np.hstack(running_labels)
conf_results = np.hstack(running_results)

In [None]:
# Predictions are binarised (`> 0`) before comparison, so any non-zero rounded
# score counts as a positive call.  First argument = true labels, second = predictions.
confusion_mat = confusion_matrix(conf_labels, conf_results > 0)

In [None]:
# Display the raw confusion-matrix counts.
confusion_mat

In [None]:
# Row 0 normalised to fractions of that row's total (per scikit-learn's
# convention, rows index the true label).
confusion_mat[0, :] / np.sum(confusion_mat[0, :])

In [None]:
# Row 1 normalised to fractions of that row's total (per scikit-learn's
# convention, rows index the true label).
confusion_mat[1, :] / np.sum(confusion_mat[1, :])