In [None]:
import os
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report, f1_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from collections import defaultdict
import random
from pathlib import Path
from google.colab import files

In [None]:
# Open Colab's file-upload widget.
# NOTE(review): presumably this is where the Kaggle API token (kaggle.json) is uploaded so the
# CLI call below can authenticate — confirm, and confirm the CLI can find the file where it lands.
files.upload()
# Download the `vetrik/oasis-alzheimer` dataset archive with the Kaggle CLI.
!kaggle datasets download -d vetrik/oasis-alzheimer
# Extract the archive into ./oasis_alzheimer_data.
!unzip oasis-alzheimer.zip -d ./oasis_alzheimer_data

## Hyperparameters

In [None]:
class Hparams:
    """Plain settings container for the notebook.

    Every constructor argument is stored unchanged as an instance attribute
    with the same name.
    """

    def __init__(self, train_batch_size=64, test_batch_size=64, learning_rate=0.0005,
                 num_epochs=10, val_split=0.1, test_split=0.1, model_path='saved_model',
                 dataset_path='/content/oasis_alzheimer_data/content/Data', seed=42):
        # Data location and split fractions.
        self.dataset_path = dataset_path
        self.val_split, self.test_split = val_split, test_split

        # Batch sizes for the data loaders.
        self.train_batch_size, self.test_batch_size = train_batch_size, test_batch_size

        # Optimisation schedule.
        self.learning_rate, self.num_epochs = learning_rate, num_epochs

        # Output location and random seed.
        self.model_path = model_path
        self.seed = seed

## Split data set

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from sklearn.model_selection import GroupKFold, train_test_split

class CustomDataset(Dataset):
    """Wrap an indexable collection of ``(x, y)`` pairs and attach per-sample statistics.

    On construction every sample is loaded once (with ``transform`` applied when
    one is given) and its ``torch.mean`` / ``torch.std`` are cached. Indexing
    returns ``(x, mean, std, y)``.

    Args:
        subset: Indexable object with ``__len__`` whose items are ``(x, y)`` pairs.
        transform: Optional callable applied to ``x`` each time a sample is loaded.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

        # One cached statistic per sample, aligned with the subset's indices.
        self.means = []
        self.stds = []
        for index in range(len(subset)):
            sample, _ = self._fetch(index)
            self.means.append(torch.mean(sample))
            self.stds.append(torch.std(sample))

    def _fetch(self, index):
        """Load item ``index`` from the subset and apply the transform if one is set."""
        sample, label = self.subset[index]
        if self.transform:
            sample = self.transform(sample)
        return sample, label

    def __getitem__(self, index):
        sample, label = self._fetch(index)
        return sample, self.means[index], self.stds[index], label

    def __len__(self):
        return len(self.subset)

def get_transforms():
    """Return the preprocessing pipeline: resize to 248x248, then convert to a tensor."""
    # NOTE(review): 248 is not a multiple of 16, the default ConvMixer patch size in this
    # notebook — confirm the chosen size is intentional.
    target_size = (248, 248)
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

def get_sample_weights(targets, indices):
    """Compute inverse-class-frequency sample weights for the selected samples.

    Args:
        targets: Indexable collection of class labels for the whole dataset.
        indices: Iterable of positions in ``targets`` that make up the subset.

    Returns:
        A 1-D ``torch.float64`` tensor with one weight per entry of ``indices``
        (same order); each weight is ``1 / count`` of that sample's class
        within the subset.
    """
    y_train = np.asarray([targets[i] for i in indices])

    # `inverse` maps every sample to the position of its class in the sorted
    # unique-label array, so the lookup stays correct even when labels are not
    # contiguous from 0 (e.g. a class that is absent from this subset).
    _, inverse, counts = np.unique(y_train, return_inverse=True, return_counts=True)
    sample_weights = 1.0 / counts[inverse.reshape(-1)]
    return torch.from_numpy(sample_weights).double()

def get_subject_ids_from_paths(image_paths):
    """Return the subject ID embedded in each image path's filename.

    The filename is the last ``os.path.sep``-separated component of the path,
    and the subject ID is its second ``'_'``-separated field.
    """
    def _subject_of(path):
        filename = path.split(os.path.sep)[-1]
        return filename.split('_')[1]

    return [_subject_of(path) for path in image_paths]

def get_data_loaders(hparams):
    """Build subject-disjoint train / validation / test data loaders.

    Images are read with ``ImageFolder`` from ``hparams.dataset_path`` and
    converted to grayscale. Subjects (parsed from the filenames) are split
    first into train+val vs. test and then into train vs. val, so that all
    images of one subject end up in exactly one split. The training loader
    draws samples with replacement using inverse-class-frequency weights.

    Args:
        hparams: Settings object providing ``dataset_path``, ``test_split``,
            ``val_split``, ``seed``, ``train_batch_size`` and ``test_batch_size``.

    Returns:
        ``(train_loader, val_loader, test_loader)``; every batch is
        ``(image, mean, std, target)`` as produced by ``CustomDataset``.
    """
    # Load the dataset; images are converted to single-channel here and
    # resized / tensorised later by the CustomDataset transform.
    dataset = datasets.ImageFolder(hparams.dataset_path,
                                   transform=transforms.Compose([transforms.Grayscale()]))

    # One subject ID per image, in dataset order.
    image_paths = [img[0] for img in dataset.imgs]
    subject_ids = get_subject_ids_from_paths(image_paths)

    # First split: hold out the test subjects.
    subjects = np.unique(subject_ids)
    train_val_subjects, test_subjects = train_test_split(
        subjects,
        test_size=hparams.test_split,
        random_state=hparams.seed
    )

    # Second split: validation subjects out of the remainder. The fraction is
    # rescaled so that val_split stays relative to the full set of subjects.
    train_subjects, val_subjects = train_test_split(
        train_val_subjects,
        test_size=hparams.val_split / (1 - hparams.test_split),
        random_state=hparams.seed
    )

    # Sets give O(1) membership tests; `x in ndarray` rescans the array for
    # every image.
    train_subject_set = set(np.asarray(train_subjects).tolist())
    val_subject_set = set(np.asarray(val_subjects).tolist())
    test_subject_set = set(np.asarray(test_subjects).tolist())

    train_indices = [i for i, subj in enumerate(subject_ids) if subj in train_subject_set]
    val_indices = [i for i, subj in enumerate(subject_ids) if subj in val_subject_set]
    test_indices = [i for i, subj in enumerate(subject_ids) if subj in test_subject_set]

    train_subset = torch.utils.data.Subset(dataset, train_indices)
    val_subset = torch.utils.data.Subset(dataset, val_indices)
    test_subset = torch.utils.data.Subset(dataset, test_indices)

    print(f"Total samples: {len(dataset)}")
    print(f"Class distribution: {np.bincount(dataset.targets)}")
    print(f"Training subjects: {len(train_subjects)}, Val subjects: {len(val_subjects)}, Test subjects: {len(test_subjects)}")
    print(f"Train size: {len(train_subset)}, Val size: {len(val_subset)}, Test size: {len(test_subset)}")

    data_transforms = get_transforms()

    # CustomDataset precomputes a mean / std for every sample at construction.
    train_dataset = CustomDataset(train_subset, transform=data_transforms)
    val_dataset = CustomDataset(val_subset, transform=data_transforms)
    test_dataset = CustomDataset(test_subset, transform=data_transforms)

    # Class-balanced sampling for training; the generator is seeded from
    # hparams so the sampling sequence is reproducible.
    sample_weights = get_sample_weights(dataset.targets, train_indices)
    train_sampler = torch.utils.data.WeightedRandomSampler(
        sample_weights,
        len(sample_weights),
        replacement=True,
        generator=torch.Generator().manual_seed(hparams.seed)
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hparams.train_batch_size,
        sampler=train_sampler,
        drop_last=True
    )

    # No drop_last for evaluation: every validation sample must be scored, and
    # a validation split smaller than one batch must still yield a batch.
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=hparams.train_batch_size
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=hparams.test_batch_size
    )

    return train_loader, val_loader, test_loader


## Model: ConvMixer

In [None]:
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from sklearn.metrics import f1_score
from torchvision import datasets, transforms
from tqdm import tqdm


class PatchExtractor(nn.Module):
    """Cut a batch of images into non-overlapping square patches and flatten each patch.

    Input of shape ``(batch, channels, height, width)`` becomes
    ``(batch, num_patches, channels * patch_size * patch_size)``, with patches
    ordered row by row.
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        p = self.patch_size
        batch_size, channels, height, width = x.shape
        assert height % p == 0 and width % p == 0, \
            "Image dimensions must be divisible by the patch size"

        grid_h, grid_w = height // p, width // p

        # Slide a p-sized window with stride p along height, then along width:
        # the result is (batch, channels, grid_h, grid_w, p, p).
        tiles = x.unfold(2, p, p).unfold(3, p, p)
        # Bring the patch-grid axes forward so that each patch's channel and
        # pixel values sit next to each other before flattening.
        tiles = tiles.permute(0, 2, 3, 1, 4, 5)
        return tiles.reshape(batch_size, grid_h * grid_w, channels * p * p)


class ConvMixerBlock(nn.Module):
    """One mixing block: a depthwise-convolution branch followed by a 1x1-convolution MLP.

    Both branches are added back onto their input, so the output has the same
    shape as the input.

    Args:
        dim: Number of input (and output) channels.
        kernel_size: Kernel size of the depthwise convolution.
    """

    def __init__(self, dim, kernel_size=9):
        super().__init__()
        expanded = dim * 4

        # groups=dim gives every channel its own spatial filter.
        depthwise = nn.Conv2d(dim, dim, kernel_size, padding="same", groups=dim)
        self.residual = nn.Sequential(depthwise, nn.GELU(), nn.BatchNorm2d(dim))

        # 1x1 convolutions mix channels: expand by 4x, then project back.
        channel_mixing = [
            nn.Conv2d(dim, expanded, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(expanded),
            nn.Conv2d(expanded, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
        ]
        self.mlp = nn.Sequential(*channel_mixing)

    def forward(self, x):
        x = x + self.residual(x)
        return x + self.mlp(x)


class ConvMixer(nn.Module):
    """ConvMixer-style classifier for single-channel images.

    A strided convolution embeds non-overlapping patches, a stack of
    ``ConvMixerBlock`` layers mixes them, and global average pooling followed
    by a linear layer yields one score per class.

    ``forward`` returns raw logits. Losses such as ``nn.CrossEntropyLoss``
    apply log-softmax themselves, so applying softmax inside the model as well
    would normalise twice and squash the gradients. Use
    ``model.softmax(logits)`` when probabilities are needed; ``argmax`` over
    logits and over probabilities selects the same class.

    Args:
        num_classes: Number of output classes.
        patch_size: Side length (and stride) of the patch-embedding convolution.
        dim: Number of channels used throughout the mixer blocks.
        depth: Number of ``ConvMixerBlock`` layers.
        image_size: Accepted for interface compatibility; not used by the model.
    """

    def __init__(self, num_classes=4, patch_size=16, dim=256, depth=8, image_size=240):
        super().__init__()
        self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)
        self.convmixer_blocks = nn.Sequential(*[ConvMixerBlock(dim) for _ in range(depth)])
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(dim, num_classes)
        # Not applied in forward(); kept so callers can turn logits into
        # probabilities explicitly.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, img, mean=None, std=None):
        """Return class logits of shape ``(batch, num_classes)``.

        ``mean`` and ``std`` are accepted so the model can be called with the
        statistics the data loaders yield, but they are not used.
        """
        x = self.patch_embedding(img)
        x = self.convmixer_blocks(x)
        x = self.pooling(x)
        x = self.flatten(x)
        return self.linear(x)


def train(model, train_loader, criterion, optimizer, device, epoch, num_epochs):
    """Run one training epoch and report loss, accuracy and macro-F1.

    Args:
        model: Module called as ``model(img, mean, std)``.
        train_loader: A ``tqdm`` progress bar wrapping ``enumerate(loader)``:
            iterating it must yield ``(batch_idx, (img, mean, std, target))``
            and it must provide ``set_description`` / ``set_postfix``.
        criterion: Loss callable taking ``(output, target)``.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batch tensors are moved to.
        epoch: Zero-based index of the current epoch (for the progress bar).
        num_epochs: Total number of epochs (for the progress bar).

    Returns:
        ``(train_loss, train_accuracy, f1)``: mean per-sample loss, accuracy in
        percent, and macro-averaged F1 over all batches of the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0      # running sum of per-sample losses
    train_correct = 0
    seen = 0            # samples processed so far

    # A mean-reduced criterion returns the batch average, so it has to be
    # re-weighted by the batch size before batches can be added together.
    mean_reduced = getattr(criterion, 'reduction', 'mean') == 'mean'

    targets, preds = [], []

    for batch_idx, (img, mean, std, target) in train_loader:
        img, mean, std, target = img.to(device), mean.to(device), std.to(device), target.to(device)
        batch_size = len(img)

        optimizer.zero_grad()
        output = model(img, mean, std)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        loss_sum += (loss.item() * batch_size) if mean_reduced else loss.item()
        seen += batch_size

        pred = output.argmax(dim=1, keepdim=True)
        train_correct += pred.eq(target.view_as(pred)).sum().item()

        targets.append(target.cpu().numpy())
        preds.append(pred.cpu().numpy().flatten())

        train_loader.set_description(f'Epoch [{epoch+1}/{num_epochs}]')
        train_loader.set_postfix(loss=loss_sum / seen, accuracy=100. * train_correct / seen)

    if seen == 0:
        raise ValueError('train_loader produced no batches; cannot compute training metrics')

    targets = np.concatenate(targets)
    preds = np.concatenate(preds)
    f1 = f1_score(targets, preds, average='macro')

    train_loss = loss_sum / seen
    train_accuracy = 100. * train_correct / seen
    return train_loss, train_accuracy, f1


def validate(model, val_loader, criterion, device):
    """Evaluate the model on a loader without updating it.

    Args:
        model: Module called as ``model(img, mean, std)``.
        val_loader: Iterable of ``(img, mean, std, target)`` batches.
        criterion: Loss callable taking ``(output, target)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(val_loss, val_accuracy)``: mean per-sample loss and accuracy in percent.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0      # running sum of per-sample losses
    val_correct = 0
    total_size = 0

    # A mean-reduced criterion returns the batch average; weight it by the
    # batch size so that unequal batches contribute proportionally.
    mean_reduced = getattr(criterion, 'reduction', 'mean') == 'mean'

    with torch.no_grad():
        for img, mean, std, target in val_loader:
            img, mean, std, target = img.to(device), mean.to(device), std.to(device), target.to(device)
            batch_size = len(img)

            output = model(img, mean, std)
            loss = criterion(output, target)
            loss_sum += (loss.item() * batch_size) if mean_reduced else loss.item()

            pred = output.argmax(dim=1, keepdim=True)
            val_correct += pred.eq(target.view_as(pred)).sum().item()
            total_size += batch_size

    if total_size == 0:
        raise ValueError('val_loader produced no samples; cannot compute validation metrics')

    val_loss = loss_sum / total_size
    val_accuracy = 100. * val_correct / total_size
    return val_loss, val_accuracy


def predict(model, data_loader, criterion, device, eval=False):
    """Collect the model's predictions (and optionally metrics) over a loader.

    Args:
        model: Module called as ``model(img, mean, std)``.
        data_loader: Iterable of ``(img, mean, std, target)`` batches.
        criterion: Loss callable taking ``(output, target)``.
        device: Device the batch tensors are moved to.
        eval: When true, the mean loss and accuracy are returned as well.

    Returns:
        ``(predictions, ground_truths)`` as NumPy arrays, preceded by
        ``(pred_loss, pred_accuracy)`` when ``eval`` is true. ``predictions``
        keeps the ``(N, 1)`` column shape produced by ``argmax(keepdim=True)``;
        ``ground_truths`` is the concatenation of the target batches.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0      # running sum of per-sample losses
    pred_correct = 0
    total_size = 0

    # A mean-reduced criterion returns the batch average; weight it by the
    # batch size so that unequal batches contribute proportionally.
    mean_reduced = getattr(criterion, 'reduction', 'mean') == 'mean'

    # Batches are gathered in lists and concatenated once at the end, rather
    # than growing a tensor with torch.cat on every iteration.
    pred_batches, target_batches = [], []

    with torch.no_grad():
        for img, mean, std, target in data_loader:
            img, mean, std, target = img.to(device), mean.to(device), std.to(device), target.to(device)
            batch_size = len(img)

            output = model(img, mean, std)
            loss = criterion(output, target)
            loss_sum += (loss.item() * batch_size) if mean_reduced else loss.item()

            pred = output.argmax(dim=1, keepdim=True)
            pred_correct += pred.eq(target.view_as(pred)).sum().item()

            pred_batches.append(pred)
            target_batches.append(target)
            total_size += batch_size

    if total_size == 0:
        raise ValueError('data_loader produced no samples; nothing to predict')

    predictions = torch.cat(pred_batches, dim=0).cpu().numpy()
    ground_truths = torch.cat(target_batches, dim=0).cpu().numpy()

    if eval:
        pred_loss = loss_sum / total_size
        pred_accuracy = 100. * pred_correct / total_size
        return pred_loss, pred_accuracy, predictions, ground_truths
    else:
        return predictions, ground_truths


def train_and_validate(model, train_loader, val_loader, criterion, optimizer, device, num_epochs, early_stopping=None):
    """Train for up to ``num_epochs`` epochs, validating after each one.

    After every epoch a one-line summary is written; if ``early_stopping`` is
    given it is fed the validation accuracy and can end training early. Loss
    and accuracy curves are plotted once training finishes.

    Returns:
        ``(train_losses, train_accuracies, val_losses, val_accuracies)``, each
        a list with one entry per completed epoch.
    """
    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []

    for epoch in range(num_epochs):
        # train() expects a tqdm bar over enumerate(loader) so it can update the display.
        progress = tqdm(enumerate(train_loader), total=len(train_loader), ascii=' >=')
        train_loss, train_accuracy, f1 = train(model, progress, criterion, optimizer, device, epoch, num_epochs)
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        val_loss, val_accuracy = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # Update the stopping criterion first, then report the epoch, then
        # act on the decision.
        stop_now = False
        if early_stopping is not None:
            early_stopping(val_accuracy)
            stop_now = early_stopping.early_stop

        tqdm.write(f'\t => train_f1={f1:.4f}, val_loss={val_loss:.4f}, val_acc={val_accuracy:.4f}')

        if stop_now:
            print(f'Early stopping at Epoch {epoch+1}')
            break

    plot_losses(train_losses, val_losses)
    plot_accuracies(train_accuracies, val_accuracies)

    return train_losses, train_accuracies, val_losses, val_accuracies

class EarlyStopping:
    """Signal when a monitored metric has stopped improving.

    Call the instance once per epoch with the metric value. ``early_stop``
    becomes ``True`` once ``patience`` consecutive calls have failed to beat
    the best value seen so far (``ref_value``).

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        mode: ``'max'`` if larger values are better, ``'min'`` if smaller are.

    Raises:
        ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.
    """

    def __init__(self, patience=5, mode='max'):
        if mode not in ('max', 'min'):
            raise ValueError(f'Undefined mode for EarlyStopping - mode: {mode}\n'
                             'Available modes are ["max", "min"]')

        self.counter = 0
        self.patience = patience
        self.early_stop = False
        self.mode = mode
        # Start from the worst possible value so the first real value improves on it.
        self.ref_value = float('-inf') if mode == 'max' else float('inf')

    def __call__(self, value):
        # Strict comparisons: a tie is not an improvement, and a NaN metric
        # (for which every comparison is False) counts as "no improvement"
        # instead of overwriting the best value seen so far.
        if self.mode == 'max':
            improved = value > self.ref_value
        else:
            improved = value < self.ref_value

        if improved:
            self.counter = 0
            self.ref_value = value
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.early_stop = True



## Visualize

In [None]:
def plot_losses(train_losses, val_losses):
    """Draw the training and validation loss curves against an 'Epoch' axis and show the figure."""
    curves = (
        (train_losses, 'Training Loss'),
        (val_losses, 'Validation Loss'),
    )
    for values, label in curves:
        plt.plot(values, label=label)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def plot_accuracies(train_accuracies, val_accuracies):
    """Draw the training and validation accuracy curves against an 'Epoch' axis and show the figure."""
    curves = (
        (train_accuracies, 'Training Accuracy'),
        (val_accuracies, 'Validation Accuracy'),
    )
    for values, label in curves:
        plt.plot(values, label=label)

    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()

## Training and Evaluation

In [6]:
# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device.")
hparams = Hparams()

# Seed every random number generator in use before any data loader or model is
# created, so that sampling and weight initialisation are repeatable.
random.seed(hparams.seed)
np.random.seed(hparams.seed)
torch.manual_seed(hparams.seed)

train_loader, val_loader, test_loader = get_data_loaders(hparams)
# Stop once validation accuracy has not improved for 10 consecutive epochs.
early_stopping = EarlyStopping(patience=10, mode='max')

In [7]:
# Report how many batches each loader yields per pass.
for loader_name, loader in (
    ("train_loader", train_loader),
    ("val_loader", val_loader),
    ("test_loader", test_loader),
):
    print(f"The {loader_name} contains {len(loader)} batches.")

### Training

In [None]:
# Build the ConvMixer with its default constructor arguments and move it to the selected device.
model = ConvMixer().to(device)

In [None]:
# Cross-entropy classification loss.
# NOTE(review): nn.CrossEntropyLoss expects raw logits — confirm the model's forward does not
# already apply a softmax to its output.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters, using the learning rate from the hyperparameters.
optimizer = optim.Adam(model.parameters(), lr=hparams.learning_rate)

In [2]:
# Run the full training loop with early stopping; the per-epoch metric histories are kept for later inspection.
train_losses, train_accuracies, val_losses, val_accuracies = train_and_validate(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=hparams.num_epochs, early_stopping=early_stopping)

In [3]:
# Re-draw the learning curves from the returned histories.
# Note: train_and_validate already calls both plot functions before it returns,
# so these figures repeat the ones shown under the training cell.
plot_losses(train_losses, val_losses)
plot_accuracies(train_accuracies, val_accuracies)

### Testing

In [4]:
# Score the trained model on the held-out test loader, reusing the validation routine.
test_loss, test_accuracy = validate(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

In [5]:
def create_confusion_matrix(model, data_loader, device, class_names):
    """Print per-class metrics and plot a row-normalised confusion matrix.

    Args:
        model: Module called as ``model(images, None, None)``; the class with
            the highest output score is taken as the prediction.
        data_loader: Iterable of ``(images, _, _, labels)`` batches.
        device: Device the images are moved to.
        class_names: Class names, where position ``i`` names label ``i``.

    Prints a classification report and the sensitivity (recall) of every
    class, then shows a heatmap whose rows are the actual classes and whose
    cells are percentages of each row.
    """
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for images, _, _, labels in data_loader:
            images = images.to(device)
            outputs = model(images, None, None)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    # Pin the label set to the full class list. Otherwise a class that never
    # occurs in the evaluated data shrinks the matrix and breaks the pairing
    # between rows / report lines and `class_names`.
    label_ids = list(range(len(class_names)))

    print(classification_report(all_targets, all_preds, labels=label_ids, target_names=class_names))

    cm = confusion_matrix(all_targets, all_preds, labels=label_ids)

    # Sensitivity = TP / (TP + FN), i.e. the diagonal entry over its row total.
    row_totals = cm.sum(axis=1)
    for i, name in enumerate(class_names):
        tp = cm[i, i]
        sens = tp / row_totals[i] if row_totals[i] > 0 else 0  # class absent from the data
        print(f"Sensitivity for class {name}: {sens:.4f}")

    # Convert counts to row percentages; rows without samples stay at 0
    # instead of becoming NaN.
    cm_percent = np.divide(
        cm.astype('float'),
        row_totals[:, np.newaxis],
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals[:, np.newaxis] > 0,
    ) * 100

    df_cm = pd.DataFrame(cm_percent, index=class_names, columns=class_names)

    plt.figure(figsize=(10, 8))
    sn.heatmap(df_cm, annot=True, fmt='.1f', cmap='Blues',
               annot_kws={"size": 12}, vmin=0, vmax=100, cbar=True)

    plt.title('Normalized Confusion Matrix - ConvMixer', fontsize=16)
    plt.xlabel('Predicted', fontsize=14)
    plt.ylabel('Actual', fontsize=14, rotation=90, va="center")
    plt.xticks(rotation=0, ha='center', fontsize=12)
    plt.yticks(rotation=0, fontsize=12)
    plt.tight_layout()
    plt.show()


# Recover the class names from the underlying dataset:
# loader -> CustomDataset (.dataset) -> Subset (.subset) -> ImageFolder (.dataset) -> .classes
class_names = train_loader.dataset.subset.dataset.classes

# Print the per-class report and plot the confusion matrix for the test split.
create_confusion_matrix(model, test_loader, device, class_names)