In [None]:
import os

# Work from two directories up so relative paths such as
# "Data/filenames/abeta_files.txt" resolve. The target is resolved once and
# remembered, so re-running this cell does not keep climbing the tree.
if "_WORKDIR" not in globals():
    _WORKDIR = os.path.abspath("../..")
os.chdir(_WORKDIR)
import json
import random
from datetime import datetime
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from skimage.transform import resize
from sklearn.metrics import confusion_matrix
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from polygeist.utils import load_filenames_and_generate_conditions

# Introduction

This notebook uses the density maps produced in WP2, combining all density maps from a case to perform case-level classification. The classifier can perform binary and multi-class classification of pathology based on those density maps.

The steps here are to:

- Define a data handler that can abstract the density maps and classifications into tensors and labels for training and validation
- Define a small CNN that will be used during training, and any data transforms to be applied
- Perform training and validation

In [None]:
class ABetaDataset(Dataset):
    """Case-level dataset of stacked amyloid-beta density maps.

    For each case, one density map per slide label is read from the first
    ``{root}/{case}*.json`` file whose path contains that label (JSON key
    ``"densities"``), resized to ``slide_size x slide_size`` and stacked
    channel-first. Each sample therefore has shape
    ``(n_slide_labels, slide_size, slide_size)``. Slides with no matching file
    are left as all-zero channels.

    Args:
        cases: Case identifiers; each is used as a filename prefix under ``root``.
        root: Directory containing the per-slide density-map JSON files.
        transform: Optional callable applied to each sample in ``__getitem__``.
        slide_size: Side length each density map is resized to.
        device: Device on which the whole data/target tensors are stored.
        case_conditions: Optional mapping ``case -> condition string``. A case
            is labelled 1 when ``"AD"`` occurs in its condition, else 0. When
            omitted (or empty) the label falls back to 0 for case names
            containing ``"C"`` and 1 otherwise.
    """

    def __init__(
        self,
        cases,
        root,
        transform=None,
        slide_size=32,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        case_conditions=None,
    ):
        # Order of this list defines the channel order of every sample.
        slide_labels = ["1_A", "10_A", "14_A", "15_A", "09_A", "05_A"]

        self.device = device
        self.transform = transform

        samples = []
        targets = []
        for case in tqdm(cases, desc="Loading Cases"):
            samples.append(self._load_case(case, root, slide_labels, slide_size))
            targets.append(self._class_index(case, case_conditions))

        if samples:
            data = np.stack(samples)
        else:
            # Keep a well-formed (0, C, H, W) tensor for an empty case list.
            data = np.zeros((0, len(slide_labels), slide_size, slide_size))
        # Stack once in NumPy; torch.Tensor(list_of_arrays) is very slow.
        self.data = torch.as_tensor(data, dtype=torch.float32).to(self.device)
        self.targets = torch.as_tensor(targets, dtype=torch.float32).to(self.device)

    @staticmethod
    def _class_index(case, case_conditions):
        """Return the binary target for ``case``: 1 for AD, 0 otherwise."""
        if case_conditions:
            return 1 if "AD" in case_conditions[case] else 0
        # No conditions supplied: infer the label from the case name. (This
        # branch used to assign an unused ``case_index`` and then read
        # ``class_index``, raising NameError or reusing a stale label.)
        return 0 if "C" in case else 1

    @staticmethod
    def _load_case(case, root, slide_labels, slide_size):
        """Build the ``(len(slide_labels), slide_size, slide_size)`` stack for one case."""
        density_map_array = np.zeros((slide_size, slide_size, len(slide_labels)))
        # Sorted so the chosen file is deterministic when several match a label.
        # NOTE(review): the ``{case}*`` prefix and substring label matching are
        # loose (e.g. "1_A" also matches "11_A"); confirm against file naming.
        files = sorted(glob(f"{root}/{case}*.json"))
        for i, label in enumerate(slide_labels):
            match = next((path for path in files if label in path), None)
            if match is None:
                continue  # missing slide: leave this channel as zeros
            with open(match, "r") as fp:
                density_map = np.array(json.load(fp)["densities"])
            density_map_array[:, :, i] = resize(density_map, (slide_size, slide_size))
        # (H, W, C) -> (C, H, W) for Conv2d.
        return np.moveaxis(density_map_array, -1, 0)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]

        if self.transform:
            x = self.transform(x)

        return x, y

    def __len__(self):
        return len(self.data)

In [None]:
class SeqBlock(nn.Module):
    """Three stacked unpadded convolutions followed by ReLU and 2x2 max-pooling.

    Kernel sizes 4, 3 and 2 shrink each spatial side by 6 before pooling halves
    it (rounding down), i.e. ``H -> (H - 6) // 2``. Only the last convolution
    changes the channel count (``cin -> cout``).
    """

    def __init__(self, cin, cout):
        super().__init__()
        # The `c1` / `pool` names and the layer order fix the state_dict keys
        # (e.g. "c1.0.weight"); keep them stable so saved checkpoints load.
        stages = [
            nn.Conv2d(cin, cin, kernel_size=4),
            nn.Conv2d(cin, cin, kernel_size=3),
            nn.Conv2d(cin, cout, kernel_size=2),
        ]
        self.c1 = nn.Sequential(*stages)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.c1(x)
        activated = F.relu(features)
        return self.pool(activated)


class CNN(nn.Module):
    """Small case-level classifier over 6-channel density-map stacks.

    Three ``SeqBlock`` stages reduce an ``(N, 6, 64, 64)`` input to
    ``(N, 2, 2, 2)``, which flattens to the 8 features the linear head expects.
    The head width is hard-coded, so only 64x64 inputs fit.

    Args:
        binary: If True the head emits one score per sample, otherwise three.

    ``forward`` returns independent per-output sigmoid scores in (0, 1) of
    shape ``(N, 1)`` or ``(N, 3)`` (not a softmax distribution).
    """

    def __init__(self, binary=False):
        super().__init__()
        # Registration order is kept so parameter initialisation consumes the
        # RNG identically and state_dict keys are unchanged.
        self.conv1 = SeqBlock(6, 6)
        self.drop1 = nn.Dropout(p=0.15)
        self.conv2 = SeqBlock(6, 5)
        self.conv3 = SeqBlock(5, 2)
        n_outputs = 1 if binary else 3
        self.linear = nn.Linear(8, n_outputs)

    def forward(self, x):
        for stage in (self.conv1, self.drop1, self.conv2, self.conv3):
            x = stage(x)
        scores = self.linear(torch.flatten(x, start_dim=1))
        return torch.sigmoid(scores)

In [None]:
# Mapping of case ID -> condition string, used below to split AD vs. non-AD
# cases. The path is relative to the working directory set by os.chdir above.
abeta_file_list = "Data/filenames/abeta_files.txt"
case_conditions = load_filenames_and_generate_conditions(abeta_file_list)

In [None]:
# Partition cases by condition: any condition containing "AD" is an AD case;
# everything else is treated as a control. Dict order is preserved.
ad_set = [case for case, condition in case_conditions.items() if "AD" in condition]
con_set = [case for case, condition in case_conditions.items() if "AD" not in condition]

In [None]:
# Shuffle each group reproducibly before splitting. Sorting first makes the
# cell idempotent: re-running it gives the same order instead of reshuffling
# an already-shuffled list with fresh, unseeded randomness.
prop = 0.60  # fraction of each group used for training
split_seed = 0
split_rng = random.Random(split_seed)
ad_set = sorted(ad_set)
con_set = sorted(con_set)
split_rng.shuffle(ad_set)
split_rng.shuffle(con_set)
sample = int(len(ad_set) * prop)  # number of AD cases in the training split

In [None]:
# Split each group at its own `prop` cut-off. Previously the control split
# reused the AD count, and `[sample + 1:]` silently dropped one case from
# each group (it appeared in neither train nor test).
ad_cut = int(len(ad_set) * prop)
con_cut = int(len(con_set) * prop)
ad_train, ad_test = ad_set[:ad_cut], ad_set[ad_cut:]
con_train, con_test = con_set[:con_cut], con_set[con_cut:]

In [None]:
# Sanity-check the split: no case may appear in both train and test.
assert not set(ad_train + con_train) & set(ad_test + con_test), "train/test overlap"
{
    "ad_train": len(ad_train),
    "ad_test": len(ad_test),
    "con_train": len(con_train),
    "con_test": len(con_test),
}

In [None]:
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
slide_size = 64
reduce_classes = True  # NOTE(review): not read by any visible cell
batch_size = 25
# Directory of per-slide density-map JSON files (the "_64" suffix presumably
# matches slide_size; confirm). Override with the ABETA_DENSITY_ROOT
# environment variable instead of editing a hard-coded absolute path.
density_root = os.environ.get(
    "ABETA_DENSITY_ROOT", "/run/media/brad/ScratchM2/ABeta_label_dump_64"
)

train_dataset = ABetaDataset(
    ad_train + con_train,
    density_root,
    transform=None,
    slide_size=slide_size,
    device=dev,
    case_conditions=case_conditions,
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = ABetaDataset(
    ad_test + con_test,
    density_root,
    transform=None,
    slide_size=slide_size,
    device=dev,
    case_conditions=case_conditions,
)
# Evaluation needs no shuffling; a fixed order keeps results reproducible.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Train the CNN, recording per-epoch train/validation loss and accuracy, and
# checkpoint the weights whenever validation accuracy improves.
torch.manual_seed(0)  # reproducible weight init and dropout masks

model = CNN().to(dev)
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()

# One entry per epoch: summed batch loss, and fraction of samples correct.
t_loss_history = []
t_correct_history = []
v_correct_history = []
v_loss_history = []

num_epochs = 10000
log_every = 250
min_save_epoch = 3  # don't checkpoint during the first, noisy epochs
best_val = 0
# A single checkpoint file per training run, overwritten on each improvement.
best_model_path = f"mmclass_{datetime.now().strftime('%d_%m_%Y__%H_%M_%S')}_.pth"


def expected_class(outputs):
    """Collapse per-output sigmoid scores into one scalar per sample.

    Multi-output models: ``sum_k k * outputs[:, k]``, regressed onto the 0/1
    labels with MSE. Single-output (binary) models: the lone score itself,
    because the weighted sum would be identically 0 (its only weight is 0).
    """
    if outputs.shape[1] == 1:
        return outputs[:, 0]
    weights = torch.arange(outputs.shape[1], device=outputs.device, dtype=outputs.dtype)
    return (outputs * weights).sum(dim=1)


def run_epoch(loader, train):
    """One pass over ``loader``; returns (summed batch loss, sample accuracy).

    Sets train/eval mode explicitly on every call. Previously ``model.eval()``
    was never undone, so dropout was disabled from the second epoch onward.
    """
    model.train(train)
    sum_loss, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(train):  # no autograd graph during validation
        for data, labels in loader:
            if train:
                optimizer.zero_grad()
            classification = expected_class(model(data))
            loss = criterion(classification, labels)
            if train:
                loss.backward()
                optimizer.step()
            sum_loss += loss.item()
            # Count samples (not per-batch means) so a short last batch is
            # weighted correctly.
            n_correct += (torch.round(classification) == labels).sum().item()
            n_seen += len(labels)
    return sum_loss, n_correct / max(n_seen, 1)


for epoch in tqdm(range(num_epochs)):
    train_loss, train_acc = run_epoch(train_dataloader, train=True)
    val_loss, val_acc = run_epoch(test_dataloader, train=False)

    t_loss_history.append(train_loss)
    t_correct_history.append(train_acc)
    v_loss_history.append(val_loss)
    v_correct_history.append(val_acc)

    if epoch % log_every == 0:
        print(
            f"Train -> Epoch {epoch} / {num_epochs} | Loss : {train_loss}, Correct : {train_acc}"
        )
        print(
            f"Validation -> Epoch {epoch} / {num_epochs} | Loss : {val_loss}, Correct : {val_acc}"
        )

    # Was `num_epochs > 3`, which is always true; the guard is meant per epoch.
    if val_acc > best_val and epoch > min_save_epoch:
        torch.save(model.state_dict(), best_model_path)
        best_val = val_acc

In [None]:
# Load the most recently written checkpoint. A filename hard-coded from one
# specific past run breaks Restart & Run All and silently evaluates a stale
# model; map_location lets a GPU-trained checkpoint load on CPU.
checkpoints = glob("mmclass_*_.pth")
if not checkpoints:
    raise FileNotFoundError("No mmclass_*_.pth checkpoint found; run the training cell first.")
best_checkpoint = max(checkpoints, key=os.path.getmtime)
model.load_state_dict(torch.load(best_checkpoint, map_location=dev))
model.eval()

Train -> Epoch 9750 / 10000 | Loss : 0.049242664128541946, Correct : 0.9545454978942871
Validation -> Epoch 9750 / 10000 | Loss : 0.5373633205890656, Correct : 0.7199999690055847

In [None]:
# Collect rounded predictions and labels for the whole test set, one list per
# batch (flattened in the next cell).
model.eval()  # disable dropout for inference
running_results = []
running_labels = []
with torch.no_grad():  # inference only; no autograd graph needed
    for data, labels in test_dataloader:
        outputs = model(data)
        weights = torch.arange(outputs.shape[1], device=outputs.device)
        classification = torch.sum(outputs * weights, 1)
        running_results.append(torch.round(classification).tolist())
        running_labels.append(labels.tolist())

In [None]:
# Flatten the per-batch lists into 1-D arrays aligned sample-by-sample.
conf_labels, conf_results = (
    np.hstack(batches) for batches in (running_labels, running_results)
)

In [None]:
# Binarise the rounded scores (anything above 0 counts as a class-1 / AD
# prediction). Fixing labels=[0, 1] keeps the matrix 2x2 even if one class
# is absent from this split, so the row indexing below cannot fail.
confusion_mat = confusion_matrix(
    conf_labels.astype(int), (conf_results > 0).astype(int), labels=[0, 1]
)

# ⚠️ Health Warning ⚠️
We have far too few cases to train and test the model rigorously. This is a simple test to see whether there is some differentiation between the groups.

In [None]:
# Rows are true labels and columns are predicted labels (sklearn convention);
# index 0 = non-AD, 1 = AD.
confusion_mat

In [None]:
# TODO: As in WP3 notebook re confusion matrix rendering
# Row-normalise: each printed row is the fraction of that true class assigned
# to each predicted class. Works for any number of classes, and rows with no
# samples print as zeros instead of NaN (divide-by-zero).
row_totals = confusion_mat.sum(axis=1, keepdims=True)
row_fractions = np.divide(
    confusion_mat,
    row_totals,
    out=np.zeros(confusion_mat.shape, dtype=float),
    where=row_totals > 0,
)
for row in row_fractions:
    print(row)