# Brain Disease Diagnosis with Classifier

## Setup imports

In [None]:
import torch
import numpy as np
import glob
import os
import pandas as pd
import logging
import time
import matplotlib.pyplot as plt
%matplotlib inline

from monai.config import print_config
from monai.data import Dataset, DataLoader
from monai.networks.nets import Classifier
from monai.transforms import (
    Compose,
    LoadImage,
    EnsureChannelFirst,
    SpatialCrop,
    Resize,
    NormalizeIntensity,
)
from monai.utils import first, set_determinism

from sklearn.model_selection import train_test_split
import torchinfo

print_config()

## Set deterministic training for reproducibility

In [16]:
# Fix the random seed through MONAI's helper so that repeated runs of the
# notebook are comparable.
set_determinism(seed=0)

## Setup directories and data

In [None]:
# Root folder of the project data; the code below expects a "train" folder
# and a "group_train.csv" file directly underneath it.
root_dir = "C:\\BrainDiseaseDiagnosis\\Brain"
print(root_dir)

# Output folder for this experiment (created on first use).
model_dir = os.path.join(root_dir, "Classifier_MultiClass_Batch20_LR5")
os.makedirs(model_dir, exist_ok=True)

# Training volumes, sorted by file name.
# NOTE(review): the labels in the CSV are paired with these files purely by
# position - confirm that the CSV rows follow the same order.
train_pattern = os.path.join(root_dir, "train", "*.nii.gz")
images = sorted(glob.glob(train_pattern))

# Label table; the "Group" column is read from it when the dataset is built.
label_csv = os.path.join(root_dir, "group_train.csv")
df = pd.read_csv(label_csv)

def groupname(groupidx):
    """Translate a numeric group index into its human-readable name.

    Args:
        groupidx: Class index (0, 1 or 2).

    Returns:
        The matching group name, or ``None`` when the index is not one of
        the known classes.
    """
    names_by_index = {
        0: "Healthy Control",
        1: "Parkinson's Disease",
        2: "Prodromal Parkinson's Disease",
    }
    try:
        return names_by_index[groupidx]
    except KeyError:
        # Unknown index: callers receive None rather than an exception.
        return None

## Setup logging

In [18]:
# Send INFO-level (and higher) records to a log file stored next to the
# other outputs of this experiment.
log_file = os.path.join(model_dir, "brain_disease_diagnosis.log")
logging.basicConfig(
    filename=log_file,
    level=logging.INFO,
    format="%(asctime)s -  %(message)s",
)

# Root logger, used by the training and reporting cells.
logger = logging.getLogger()

## Setup transforms and dataset

In [None]:
# Dataset yielding (image, label, class weight) triples, or bare images when
# no labels are supplied.
class ImageLabelWeightDataset(Dataset):
    """Indexable dataset of images with optional labels and class weights.

    Args:
        img: Sequence of items to serve (for example file paths).
        labels: Optional sequence of non-negative integer class labels,
            aligned by position with ``img``.
        img_transform: Optional callable applied to an item when it is
            fetched.

    When ``labels`` is given, ``classweight`` holds one weight per class
    index, computed as ``n_samples / (n_classes * class_count)`` so that
    rarer classes receive larger weights.
    """

    def __init__(self, img, labels=None, img_transform=None):
        self.img = img
        self.labels = labels
        self.img_transform = img_transform
        # Class weights only make sense when labels are available.
        if labels is not None:
            self.compute_classweight()

    def compute_classweight(self):
        """Compute inverse-frequency class weights from ``self.labels``."""
        per_class = np.bincount(self.labels)
        n_classes = len(per_class)
        n_samples = len(self.labels)
        self.classweight = n_samples / (n_classes * per_class)

    def __len__(self) -> int:
        """Return the number of items in ``img``."""
        return len(self.img)

    def __getitem__(self, index):
        """Return the transformed item, plus label and weight if labelled."""
        sample = self.img[index]
        if self.img_transform:
            sample = self.img_transform(sample)

        # Unlabelled mode (e.g. inference): only the image is returned.
        if self.labels is None:
            return sample

        target = self.labels[index]
        return sample, target, self.classweight[target]

# Mini-batch size used by both data loaders.
batch_size = 20

# Image pipeline: load the volume, move channels first, crop a
# 160x192x160 region centred at (84, 102, 84), resample to 64x64x64 with
# trilinear interpolation, then normalise the non-zero voxels per channel.
preprocessing_steps = [
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    SpatialCrop(roi_center=(84, 102, 84), roi_size=(160, 192, 160)),
    Resize((64, 64, 64), mode="trilinear"),
    NormalizeIntensity(nonzero=True, channel_wise=True),
]
imtrans = Compose(preprocessing_steps)

group_labels = df["Group"].values
ds = ImageLabelWeightDataset(img=images, labels=group_labels, img_transform=imtrans)

# Stratified 80/20 split on the group label.
# NOTE(review): no random_state is passed, so the split depends on the global
# NumPy random state at the time this cell runs.
# NOTE(review): train_test_split appears to index `ds` element by element,
# which would load and transform every image up front - confirm.
train_ds, val_ds = train_test_split(ds, stratify=group_labels, test_size=0.2)

# Options shared by both loaders; only the training loader shuffles.
loader_options = {
    "batch_size": batch_size,
    "num_workers": 0,
    "pin_memory": torch.cuda.is_available(),
}
train_loader = DataLoader(train_ds, shuffle=True, **loader_options)
val_loader = DataLoader(val_ds, **loader_options)

# Check data shape: report the shapes of the first batch of each loader and
# the number of batches.
tr = first(train_loader)
print(f"training: ({list(tr[0].shape)}, {list(tr[1].shape)}) \u00D7 {len(train_loader)}")
vl = first(val_loader)
print(f"validation: ({list(vl[0].shape)}, {list(vl[1].shape)}) \u00D7 {len(val_loader)}")

## Check data shape and visualize

In [None]:
# Display one slice (index `sliceidx` along the last axis) of the first image
# in the first training batch, titled with its group name, and save the
# figure as a TIFF in the model directory.
sliceidx = 30
fig = plt.figure("Example image for training", (12, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
ax.set_title(f"Group: {groupname(tr[1][0].item())}")
ax.imshow(np.rot90(tr[0][0, 0, :, :, sliceidx].detach().cpu()), cmap="gray")
ax.axis("off")
# Save before showing: showing may release the figure on some backends.
plt.savefig(os.path.join(model_dir, "image_group.tif"), dpi=300)
# `plt.show` must be called; the bare name only evaluates to the function
# object and shows nothing outside an inline notebook backend.
plt.show()

## Create model

In [21]:
# Training hyper-parameters.
max_epochs = 100
val_interval = 1  # run validation every `val_interval` epochs
lr = 1e-4

# Create Classifier, CrossEntropyLoss, and Adam optimizer.
# Use CUDA when present and fall back to the CPU otherwise, so the notebook
# does not crash on machines without a GPU. The rest of the notebook already
# branches on torch.cuda.is_available(), so this keeps both in agreement.
# To force another backend, override manually, e.g. torch.device("mps").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Classifier for single-channel 64x64x64 inputs and three output classes.
model = Classifier(
    in_shape=[1, 64, 64, 64],
    classes=3,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2),
    kernel_size=3,
    num_res_units=2,
).to(device)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=1e-5)
# Cosine learning-rate schedule spanning the whole training run.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

# Gradient scaler for mixed-precision training; only created when CUDA is
# available, which is the same condition the training loop checks before
# using it.
if torch.cuda.is_available():
    scaler = torch.cuda.amp.GradScaler()

## Print model

In [None]:
# Summarise the model for one batch of single-channel 64x64x64 inputs.
torchinfo.summary(model, input_size=(batch_size, 1, 64, 64, 64))

## Execute training process

In [None]:
# Book-keeping for the best validation accuracy and per-epoch curves.
best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]  # [accuracies, epochs, elapsed seconds]
epoch_loss_values = []    # mean weighted training loss per epoch
epoch_metric_values = []  # training accuracy per epoch
metric_values = []        # validation accuracy per validation run

# Class weights must multiply each sample's own loss. A loss reduced to a
# scalar (the default "mean" reduction) would only be rescaled by the batch's
# average weight, so an unreduced cross-entropy is used for training here.
per_sample_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
# The gradient scaler only exists when CUDA is available.
# NOTE(review): no autocast context is used, so the forward pass still runs
# in full precision and the scaler merely scales the loss.
use_scaler = torch.cuda.is_available()
# Number of batches per epoch, including a final partial batch.
steps_per_epoch = len(train_loader)

total_start = time.time()
for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training pass -------------------------------------------------
    model.train()
    epoch_loss = 0
    step = 0
    num_correct = 0.0
    metric_count = 0
    for batch_data in train_loader:
        step_start = time.time()
        step += 1
        im, lab, weight = batch_data
        inputs, labels, sample_weights = (
            im.to(device),
            lab.to(device),
            # Cast before the transfer: the weights are float64 NumPy values
            # and some backends do not support float64 tensors.
            weight.float().to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = per_sample_loss_fn(outputs, labels)  # shape: (batch,)
        weighted_loss = (loss * sample_weights).mean()  # Apply class weights
        if use_scaler:
            scaler.scale(weighted_loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            weighted_loss.backward()
            optimizer.step()
        epoch_loss += weighted_loss.item()
        print(
            f"{step}/{steps_per_epoch}"
            # Unweighted batch mean, i.e. the plain cross-entropy value.
            f", train_loss: {loss.mean().item():.4f}"
            f", step time: {(time.time() - step_start):.4f}"
        )

        # Accumulate training accuracy over the epoch.
        value = torch.eq(outputs.argmax(dim=1), labels)
        metric_count += len(value)
        num_correct += value.sum().item()
    lr_scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    metric = num_correct / metric_count
    epoch_metric_values.append(metric)

    # ---- validation pass -----------------------------------------------
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            num_correct = 0.0
            metric_count = 0
            for val_data in val_loader:
                # The per-sample weight is not needed for accuracy.
                im, lab, _ = val_data
                val_inputs, val_labels = (
                    im.to(device),
                    lab.to(device),
                )
                val_outputs = model(val_inputs)
                value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                metric_count += len(value)
                num_correct += value.sum().item()
            metric = num_correct / metric_count
            metric_values.append(metric)

            # Keep only the checkpoint with the best validation accuracy.
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                torch.save(
                    model.state_dict(),
                    os.path.join(model_dir, "best_metric_model.pth"),
                )
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current accuracy: {metric:.4f}"
                f"\nbest accuracy: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
            logger.info(f"epoch {epoch + 1} accuracy: {metric:.4f}")
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

## Plot loss and metric

In [None]:
# Report the overall result of the training run.
# NOTE(review): the elapsed time is measured when this cell runs, so it also
# includes any pause between the end of training and running this cell.
total_time = time.time() - total_start
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}, total time: {total_time}.")
logger.info(
    f"best_metric: {best_metric:.4f} at epoch {best_metric_epoch}, "
    f"total time to train: {total_time}"
)

fig = plt.figure("Performance in training", (12, 6))

# Left panel: mean weighted training loss per epoch.
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("Loss")
x = list(range(1, len(epoch_loss_values) + 1))
y = epoch_loss_values
ax1.plot(x, y, color="red")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")

# Right panel: training accuracy (every epoch) against validation accuracy
# (every `val_interval` epochs).
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title("Accuracy")
# The x values follow the series they are plotted against.
x1 = list(range(1, len(epoch_metric_values) + 1))
x2 = [val_interval * (i + 1) for i in range(len(metric_values))]
y1 = epoch_metric_values
y2 = metric_values
ax2.plot(x1, y1, color="red")
ax2.plot(x2, y2, color="blue")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.legend(["Train", "Validation"])

plt.savefig(os.path.join(model_dir, "performance.tif"), dpi=300)
# Call the function; the bare name `plt.show` displays nothing.
plt.show()

## Check best model output

In [None]:
# Index of the validation sample to inspect.
testidx = 3

# Restore the best checkpoint. map_location makes the weights load onto the
# device in use, even if they were saved from a different one.
model.load_state_dict(
    torch.load(os.path.join(model_dir, "best_metric_model.pth"), map_location=device)
)
model.eval()

# Fetch the sample once; each validation item is (image, label, weight).
val_sample = val_ds[testidx]
val_image, val_label = val_sample[0], val_sample[1]

with torch.no_grad():
    # Select one image to evaluate and visualize the model output
    val_input = val_image.unsqueeze(0).to(device)  # add the batch dimension
    val_output = model(val_input)

fig = plt.figure("Actual vs. Predicted", (12, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
ax.set_title(f"Actual group = {groupname(val_label.item())}"
             f"\nPredicted group = {groupname(val_output.argmax(dim=1).item())}")
# Show slice 30 along the last axis of the first channel.
ax.imshow(np.rot90(val_image[0, :, :, 30].detach().cpu()), cmap="gray")
ax.axis("off")
plt.savefig(os.path.join(model_dir, "actual_predicted.tif"), dpi=300)
# Call the function; the bare name `plt.show` displays nothing.
plt.show()

## Apply best model

In [37]:
# Define NIfTI dataset for the unlabelled test images. Without labels the
# dataset returns only the transformed image for each index.
test_images = sorted(glob.glob(os.path.join(root_dir, "test", "*.nii.gz")))
test_ds = ImageLabelWeightDataset(img=test_images, img_transform=imtrans)

# Apply the best model and save predictions. map_location makes the weights
# load onto the device in use, even if they were saved from a different one.
model.load_state_dict(
    torch.load(os.path.join(model_dir, "best_metric_model.pth"), map_location=device)
)
model.eval()
test_predictions = []

with torch.no_grad():
    for idx in range(len(test_ds)):
        test_input = test_ds[idx].unsqueeze(0).to(device)  # add the batch dimension
        test_output = model(test_input)
        test_predictions.append(test_output.argmax(dim=1).item())

# One predicted class index per line, in the order of `test_images`. Written
# as integers: the default savetxt format would store each class label in
# scientific float notation.
np.savetxt(os.path.join(model_dir, "BrainDiseaseDiagnosis.txt"), test_predictions, fmt="%d")