In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import torch
from data_analysis_preparation.utils import filter_df_by_problem

from dataloader import Mri3DDataLoader
from model import MRINet
from training_loop import evaluate_model, run_training_loop

### Load data

In [2]:
# Locations of the tabular split files and of the checkpoint directory.
# Everything lives under a single project root, so the root is spelled out once.
PROJECT_ROOT = "/home/dpolak/alzheimer_disease_classification_tools"
TABULAR_DATA_DIR = f"{PROJECT_ROOT}/tabular_data"

TRAIN_DATA_PATH = f"{TABULAR_DATA_DIR}/train_base.csv"
TEST_DATA_PATH = f"{TABULAR_DATA_DIR}/test_base.csv"
VALID_DATA_PATH = f"{TABULAR_DATA_DIR}/val_base.csv"
BIOCARD_DATA_PATH = f"{TABULAR_DATA_DIR}/biocard_test_set.csv"

# Per-problem sub-directories and the history figures are written below this directory.
MODEL_SAVE_DIRECTORY = f"{PROJECT_ROOT}/models"

In [3]:
# Read the four tabular CSV files (train, test, validation, BIOCARD) into DataFrames,
# in that order.
train_df, test_df, valid_df, biocard_df = (
    pd.read_csv(csv_path)
    for csv_path in (TRAIN_DATA_PATH, TEST_DATA_PATH, VALID_DATA_PATH, BIOCARD_DATA_PATH)
)

### Training and evaluation process

In [4]:
# Classification problems to train, mapped to the class labels each one uses.
# Insertion order matters: it is the order in which the problems are trained and plotted.
problems = {
    "AD vs CN": ["AD", "CN"],
    "AD vs MCI": ["AD", "MCI"],
    "MCI vs CN": ["MCI", "CN"],
    "AD vs MCI vs CN": ["MCI", "CN", "AD"],
    "p-MCI vs np-MCI": ["p-MCI", "np-MCI"],
}

# Batch size shared by every data loader and passed to the training loop.
batch_size = 4

In [5]:
# Baseline experiment (no augmentation): for every classification problem, train a
# fresh MRINet, load best_model.pth from the problem's save directory and evaluate
# that model on each available split.
histories = {}
results = {problem: {} for problem in problems}

for problem, columns in problems.items():
    print(f"Training {problem}")

    # One loader per split, each built from the frame returned by filter_df_by_problem.
    _train_df = Mri3DDataLoader(
        filter_df_by_problem(problem, train_df),
        classification_values=columns,
        batch_size=batch_size,
    )
    _test_df = Mri3DDataLoader(
        filter_df_by_problem(problem, test_df),
        classification_values=columns,
        batch_size=batch_size,
    )
    _valid_df = Mri3DDataLoader(
        filter_df_by_problem(problem, valid_df),
        classification_values=columns,
        batch_size=batch_size,
    )

    # The BIOCARD set is not evaluated for the p-MCI vs np-MCI problem.
    _biocard_df = None
    if problem != "p-MCI vs np-MCI":
        _biocard_df = Mri3DDataLoader(
            filter_df_by_problem(problem, biocard_df),
            classification_values=columns,
            batch_size=batch_size,
        )

    save_directory = Path(MODEL_SAVE_DIRECTORY, problem)

    # Only the three-way problem needs three outputs; every other problem is binary.
    num_classes = 3 if problem == "AD vs MCI vs CN" else 2
    model = MRINet(num_classes=num_classes)
    model.to("cuda")

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # NOTE(review): gamma=0.005 shrinks the learning rate 200x every 3 epochs; the
    # recorded output shows 1e-3 -> 5e-6 at epoch 3 and 2.5e-8 at epoch 6, so most of
    # the 20 epochs run with a near-zero step size. Confirm this schedule is intended.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=3, gamma=0.005, last_epoch=-1
    )

    history = run_training_loop(
        model=model,
        train_dataset=_train_df,
        valid_dataset=_valid_df,
        num_epochs=20,
        optimizer=optimizer,
        batch_size=batch_size,
        save_directory=save_directory,
        scheduler=scheduler,
    )
    histories[problem] = history

    # Evaluate the checkpoint stored as best_model.pth (presumably written by
    # run_training_loop -- confirm) rather than the weights of the last epoch.
    model.load_state_dict(torch.load(Path(save_directory, "best_model.pth")))

    split_loaders = {
        "train": _train_df,
        "test": _test_df,
        "valid": _valid_df,
        "biocard": _biocard_df,
    }
    # A split without a loader (BIOCARD for p-MCI vs np-MCI) is recorded as None.
    results[problem] = {
        split: evaluate_model(model, loader) if loader is not None else None
        for split, loader in split_loaders.items()
    }

Training AD vs CN


Epoch: 0, LR: [0.001] Step loss: 1.9558913707733154,  Step acc: 2, Train Accuracy: 0.6213592233009708, Balanced Accuracy: 0.6213235294117647, Running Loss: 0.8311814965987668: 100%|██████████| 412/412 [05:37<00:00,  1.22it/s] 
Validation Loss - item: 0.306722491979599, Validation Accuracy: 0.7009803921568627, Balanced accuracy: 0.7072453993641006, Running Loss: 0.6341335323693997: 100%|██████████| 51/51 [00:29<00:00,  1.74it/s]  



Best validation loss: 0.6341335323693997

Saving best model for epoch: 1



Epoch: 1, LR: [0.001] Step loss: 0.23181502521038055,  Step acc: 4, Train Accuracy: 0.7390776699029126, Balanced Accuracy: 0.7386567333454113, Running Loss: 0.541002038197176: 100%|██████████| 412/412 [05:35<00:00,  1.23it/s] 
Validation Loss - item: 0.10794678330421448, Validation Accuracy: 0.7549019607843137, Balanced accuracy: 0.7591290105019751, Running Loss: 0.4565514701546407: 100%|██████████| 51/51 [00:30<00:00,  1.69it/s] 



Best validation loss: 0.4565514701546407

Saving best model for epoch: 2



Epoch: 2, LR: [0.001] Step loss: 0.1466771513223648,  Step acc: 4, Train Accuracy: 0.7815533980582524, Balanced Accuracy: 0.7813108036651952, Running Loss: 0.48410505597995035: 100%|██████████| 412/412 [05:27<00:00,  1.26it/s] 
Validation Loss - item: 0.548092782497406, Validation Accuracy: 0.7598039215686274, Balanced accuracy: 0.7479044223913671, Running Loss: 0.4859967972425854: 100%|██████████| 51/51 [00:29<00:00,  1.73it/s]   
Epoch: 3, LR: [5e-06] Step loss: 0.7777667045593262,  Step acc: 2, Train Accuracy: 0.8246359223300971, Balanced Accuracy: 0.8236592446610608, Running Loss: 0.41099628636125224: 100%|██████████| 412/412 [05:26<00:00,  1.26it/s] 
Validation Loss - item: 0.15974153578281403, Validation Accuracy: 0.8480392156862745, Balanced accuracy: 0.8459870893149629, Running Loss: 0.34673308887902426: 100%|██████████| 51/51 [00:29<00:00,  1.73it/s]



Best validation loss: 0.34673308887902426

Saving best model for epoch: 4



Epoch: 4, LR: [5e-06] Step loss: 0.7941679954528809,  Step acc: 2, Train Accuracy: 0.8507281553398058, Balanced Accuracy: 0.8502783751141485, Running Loss: 0.37012551983510983: 100%|██████████| 412/412 [05:34<00:00,  1.23it/s] 
Validation Loss - item: 0.13270694017410278, Validation Accuracy: 0.8480392156862745, Balanced accuracy: 0.8483957992099431, Running Loss: 0.357341453286947: 100%|██████████| 51/51 [00:29<00:00,  1.72it/s]  
Epoch: 5, LR: [5e-06] Step loss: 0.422732412815094,  Step acc: 4, Train Accuracy: 0.850121359223301, Balanced Accuracy: 0.8497857648907763, Running Loss: 0.371188863379501: 100%|██████████| 412/412 [05:28<00:00,  1.25it/s]     
Validation Loss - item: 0.22280289232730865, Validation Accuracy: 0.8480392156862745, Balanced accuracy: 0.8483957992099431, Running Loss: 0.3544075728631487: 100%|██████████| 51/51 [00:29<00:00,  1.71it/s] 
Epoch: 6, LR: [2.5000000000000002e-08] Step loss: 0.39235633611679077,  Step acc: 3, Train Accuracy: 0.8464805825242718, Balance


Best validation loss: 0.34490293337433947

Saving best model for epoch: 10



 19%|█▉        | 80/412 [00:47<03:17,  1.68it/s]

KeyboardInterrupt



### Results

In [None]:
# Training curves for every problem: one row per problem, loss on the left and
# balanced accuracy on the right.
n_problems = len(problems)
# Size the grid to the number of problems (4 inches per row, as before) instead of a
# fixed 10 rows, which left the lower half of the figure empty.
# squeeze=False keeps `axes` two-dimensional even when there is a single problem.
fig, axes = plt.subplots(n_problems, 2, figsize=(15, 4 * n_problems), squeeze=False)

for (loss_ax, accuracy_ax), problem in zip(axes, problems):
    problem_history = histories[problem]

    loss_ax.plot(problem_history["train_loss"], label="train loss")
    loss_ax.plot(problem_history["valid_loss"], label="validation loss")
    loss_ax.set_title(problem)
    loss_ax.legend()

    accuracy_ax.plot(problem_history["train_balanced_accuracy"], label="train balanced accuracy")
    accuracy_ax.plot(problem_history["valid_balanced_accuracy"], label="validation balanced accuracy")
    accuracy_ax.set_title(problem)
    accuracy_ax.legend()

# Save BEFORE plt.show(): once the figure has been shown, pyplot's current figure is a
# new empty one, so a later plt.savefig() would write a blank image.
fig.savefig(Path(MODEL_SAVE_DIRECTORY, "training_history.png"))
plt.show()

In [None]:
# Overlay the test-set ROC curve of every binary problem on a single figure.
# The three-class problem is left out.
binary_problems = [name for name in problems if name != "AD vs MCI vs CN"]

plt.figure(figsize=(10, 10))
for problem in binary_problems:
    test_metrics = results[problem]["test"]
    # Entry 0 of "roc_curve" supplies the x values at index 0 and the y values at
    # index 1 (presumably false/true positive rates -- confirm against evaluate_model).
    first_curve = test_metrics["roc_curve"][0]
    plt.plot(
        first_curve[0],
        first_curve[1],
        label=f"{problem}, ROC AUC: {test_metrics['roc_auc']}",
    )

# Diagonal reference line, labelled as ROC AUC 0.5.
plt.plot([0, 1], [0, 1], linestyle='--', label="ROC AUC: 0.5")
plt.legend()
plt.show()

In [None]:
# One metrics table per problem, built from that problem's entry in `results`.
# The "roc_curve" row is removed when present; errors="ignore" tolerates its absence.
results_dfs = {}
for problem in problems:
    metrics_table = pd.DataFrame(results[problem])
    results_dfs[problem] = metrics_table.drop("roc_curve", axis=0, errors="ignore")

In [None]:
# Display the metrics table for the "AD vs CN" problem.
results_dfs["AD vs CN"]

In [None]:
# Display the metrics table for the "AD vs MCI" problem.
results_dfs["AD vs MCI"]

In [None]:
# Display the metrics table for the "MCI vs CN" problem.
results_dfs["MCI vs CN"]

In [None]:
# Display the metrics table for the three-class "AD vs MCI vs CN" problem.
results_dfs["AD vs MCI vs CN"]

In [None]:
# Display the metrics table for the "p-MCI vs np-MCI" problem
# (its "biocard" entry in `results` is None, since BIOCARD is not evaluated for it).
results_dfs["p-MCI vs np-MCI"]

### Augmentations

In [None]:
# Visual sanity check of the augmentation pipeline: fetch the same item 16 times from
# a loader built with augment=True and show one 2-D slice of each result.
augment_dataloader = Mri3DDataLoader(
    filter_df_by_problem("AD vs CN", train_df),
    classification_values=["AD", "CN"],
    batch_size=batch_size,
    augment=True,
)

plt.figure(figsize=(20, 20))
# Use a figure-level title. plt.title() acts on the current axes and, with no axes
# yet, creates an implicit full-figure one that the 4x4 grid then covers or replaces,
# so the heading does not end up as a proper title of the whole figure.
plt.suptitle("Augmented images")
for i in range(16):
    plt.subplot(4, 4, i + 1)
    # Item 5 is requested anew for every panel; presumably each call applies a fresh
    # random augmentation -- confirm against Mri3DDataLoader.get_single_item.
    volume = augment_dataloader.get_single_item(5)[0].cpu().numpy()
    # The array is indexed as 5-D; slice 80 along the last axis is shown
    # (presumably laid out as batch, channel and three spatial axes -- confirm).
    plt.imshow(volume[0, 0, :, :, 80], cmap="gray")
plt.show()

In [None]:
# Augmentation experiment: same procedure as the baseline run, except that the
# training loader is built with augment=True and training lasts 10 epochs.
# `histories` and `results` are reset, so the baseline values are replaced.
histories = {}
results = {problem: {} for problem in problems}

for problem, columns in problems.items():
    print(f"Training {problem}")

    # Only the training loader augments; test/validation/BIOCARD loaders do not.
    _train_df = Mri3DDataLoader(
        filter_df_by_problem(problem, train_df),
        classification_values=columns,
        batch_size=batch_size,
        augment=True,
    )
    _test_df = Mri3DDataLoader(
        filter_df_by_problem(problem, test_df),
        classification_values=columns,
        batch_size=batch_size,
    )
    _valid_df = Mri3DDataLoader(
        filter_df_by_problem(problem, valid_df),
        classification_values=columns,
        batch_size=batch_size,
    )

    # The BIOCARD set is not evaluated for the p-MCI vs np-MCI problem.
    _biocard_df = None
    if problem != "p-MCI vs np-MCI":
        _biocard_df = Mri3DDataLoader(
            filter_df_by_problem(problem, biocard_df),
            classification_values=columns,
            batch_size=batch_size,
        )

    # NOTE(review): this is the same directory the non-augmented run uses, so its
    # best_model.pth is presumably overwritten here -- confirm that is acceptable.
    save_directory = Path(MODEL_SAVE_DIRECTORY, problem)

    # Only the three-way problem needs three outputs; every other problem is binary.
    num_classes = 3 if problem == "AD vs MCI vs CN" else 2
    model = MRINet(num_classes=num_classes)
    model.to("cuda")

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # NOTE(review): StepLR multiplies the learning rate by gamma every step_size
    # epochs, i.e. by 0.005 every 3 epochs (1e-3 -> 5e-6 -> 2.5e-8). Confirm that
    # such a steep decay is intended.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=3, gamma=0.005, last_epoch=-1
    )

    history = run_training_loop(
        model=model,
        train_dataset=_train_df,
        valid_dataset=_valid_df,
        num_epochs=10,
        optimizer=optimizer,
        batch_size=batch_size,
        save_directory=save_directory,
        scheduler=scheduler,
    )
    histories[problem] = history

    # Evaluate the checkpoint stored as best_model.pth (presumably written by
    # run_training_loop -- confirm) rather than the weights of the last epoch.
    model.load_state_dict(torch.load(Path(save_directory, "best_model.pth")))

    # NOTE(review): the "train" entry is computed on the augmenting loader, so it may
    # not be directly comparable with the baseline train metrics -- confirm.
    split_loaders = {
        "train": _train_df,
        "test": _test_df,
        "valid": _valid_df,
        "biocard": _biocard_df,
    }
    # A split without a loader (BIOCARD for p-MCI vs np-MCI) is recorded as None.
    results[problem] = {
        split: evaluate_model(model, loader) if loader is not None else None
        for split, loader in split_loaders.items()
    }

In [None]:
# Training curves of the augmentation run: one row per problem, loss on the left and
# balanced accuracy on the right.
n_problems = len(problems)
# Size the grid to the number of problems (4 inches per row, as before) instead of a
# fixed 10 rows, which left the lower half of the figure empty.
# squeeze=False keeps `axes` two-dimensional even when there is a single problem.
fig, axes = plt.subplots(n_problems, 2, figsize=(15, 4 * n_problems), squeeze=False)

for (loss_ax, accuracy_ax), problem in zip(axes, problems):
    problem_history = histories[problem]

    loss_ax.plot(problem_history["train_loss"], label="train loss")
    loss_ax.plot(problem_history["valid_loss"], label="validation loss")
    loss_ax.set_title(problem)
    loss_ax.legend()

    accuracy_ax.plot(problem_history["train_balanced_accuracy"], label="train balanced accuracy")
    accuracy_ax.plot(problem_history["valid_balanced_accuracy"], label="validation balanced accuracy")
    accuracy_ax.set_title(problem)
    accuracy_ax.legend()

# Save BEFORE plt.show(): once the figure has been shown, pyplot's current figure is a
# new empty one, so a later plt.savefig() would write a blank image.
fig.savefig(Path(MODEL_SAVE_DIRECTORY, "training_history_augmentation.png"))
plt.show()

In [None]:
# Rebuild the per-problem metrics tables from the current `results`
# (at this point: the run trained with augmentation).
# The "roc_curve" row is removed when present; errors="ignore" tolerates its absence.
results_dfs = {}
for problem in problems:
    metrics_table = pd.DataFrame(results[problem])
    results_dfs[problem] = metrics_table.drop("roc_curve", axis=0, errors="ignore")

In [None]:
# Display the "AD vs CN" metrics table of the augmentation run.
results_dfs["AD vs CN"]

In [None]:
# Display the "AD vs MCI" metrics table of the augmentation run.
results_dfs["AD vs MCI"]

In [None]:
# Display the "MCI vs CN" metrics table of the augmentation run.
results_dfs["MCI vs CN"]

In [None]:
# Display the three-class "AD vs MCI vs CN" metrics table of the augmentation run.
results_dfs["AD vs MCI vs CN"]

In [None]:
# Display the "p-MCI vs np-MCI" metrics table of the augmentation run
# (its "biocard" entry in `results` is None, since BIOCARD is not evaluated for it).
results_dfs["p-MCI vs np-MCI"]